@@ -186,23 +186,23 @@ defmodule EXLA.MLIR.Value do
186
186
187
187
def reverse ( % Value { function: func } = operand , dims , typespec ) do
188
188
result_types = typespecs_to_mlir_types ( [ typespec ] )
189
- attributes = [ dimensions: attr_dense_i64_elements ( dims ) ]
189
+ attributes = [ dimensions: attr_array_i64_elements ( dims ) ]
190
190
op ( func , "stablehlo.reverse" , [ operand ] , result_types , attributes: attributes ) |> one! ( )
191
191
end
192
192
193
193
def transpose ( % Value { function: func } = operand , axes , typespec ) do
194
194
result_types = typespecs_to_mlir_types ( [ typespec ] )
195
- attributes = [ permutation: attr_dense_i64_elements ( axes ) ]
195
+ attributes = [ permutation: attr_array_i64_elements ( axes ) ]
196
196
op ( func , "stablehlo.transpose" , [ operand ] , result_types , attributes: attributes ) |> one! ( )
197
197
end
198
198
199
199
def slice ( % Value { function: func } = operand , starts , limits , strides , typespec ) do
200
200
result_types = typespecs_to_mlir_types ( [ typespec ] )
201
201
202
202
attributes = [
203
- start_indices: attr_dense_i64_elements ( starts ) ,
204
- limit_indices: attr_dense_i64_elements ( limits ) ,
205
- strides: attr_dense_i64_elements ( strides )
203
+ start_indices: attr_array_i64_elements ( starts ) ,
204
+ limit_indices: attr_array_i64_elements ( limits ) ,
205
+ strides: attr_array_i64_elements ( strides )
206
206
]
207
207
208
208
op ( func , "stablehlo.slice" , [ operand ] , result_types , attributes: attributes ) |> one! ( )
@@ -211,7 +211,7 @@ defmodule EXLA.MLIR.Value do
211
211
def dynamic_slice ( % Value { function: func } = operand , starts , lengths , typespec ) do
212
212
result_types = typespecs_to_mlir_types ( [ typespec ] )
213
213
operands = [ operand ] ++ starts
214
- attributes = [ slice_sizes: attr_dense_i64_elements ( lengths ) ]
214
+ attributes = [ slice_sizes: attr_array_i64_elements ( lengths ) ]
215
215
op ( func , "stablehlo.dynamic_slice" , operands , result_types , attributes: attributes ) |> one! ( )
216
216
end
217
217
@@ -303,7 +303,7 @@ defmodule EXLA.MLIR.Value do
303
303
result_types = typespecs_to_mlir_types ( [ typespec ] )
304
304
305
305
attributes = [
306
- broadcast_dimensions: attr_dense_i64_elements ( axes )
306
+ broadcast_dimensions: attr_array_i64_elements ( axes )
307
307
]
308
308
309
309
op ( func , "stablehlo.broadcast_in_dim" , [ operand ] , result_types , attributes: attributes )
@@ -347,9 +347,9 @@ defmodule EXLA.MLIR.Value do
347
347
{ padding_low , padding_high , padding_mid } = unzip_padding_config ( padding_config )
348
348
349
349
attributes = [
350
- edge_padding_low: attr_dense_i64_elements ( padding_low ) ,
351
- edge_padding_high: attr_dense_i64_elements ( padding_high ) ,
352
- interior_padding: attr_dense_i64_elements ( padding_mid )
350
+ edge_padding_low: attr_array_i64_elements ( padding_low ) ,
351
+ edge_padding_high: attr_array_i64_elements ( padding_high ) ,
352
+ interior_padding: attr_array_i64_elements ( padding_mid )
353
353
]
354
354
355
355
op ( func , "stablehlo.pad" , [ operand , pad ] , result_types , attributes: attributes ) |> one! ( )
@@ -375,7 +375,7 @@ defmodule EXLA.MLIR.Value do
375
375
376
376
attributes = [
377
377
fft_type: fft_type ,
378
- fft_length: attr_dense_i64_elements ( List . wrap ( fft_length ) )
378
+ fft_length: attr_array_i64_elements ( List . wrap ( fft_length ) )
379
379
]
380
380
381
381
op ( func , "stablehlo.fft" , [ value ] , result_types , attributes: attributes ) |> one! ( )
@@ -451,8 +451,8 @@ defmodule EXLA.MLIR.Value do
451
451
result_types = typespecs_to_mlir_types ( [ typespec ] )
452
452
453
453
attributes = [
454
- window_dimensions: attr_dense_i64_elements ( window_dimensions ) ,
455
- window_strides: attr_dense_i64_elements ( window_strides ) ,
454
+ window_dimensions: attr_array_i64_elements ( window_dimensions ) ,
455
+ window_strides: attr_array_i64_elements ( window_strides ) ,
456
456
padding: attr_padding ( padding )
457
457
]
458
458
@@ -501,7 +501,7 @@ defmodule EXLA.MLIR.Value do
501
501
502
502
attributes = [
503
503
dimension_numbers: dimension_numbers ,
504
- slice_sizes: attr_dense_i64_elements ( slice_sizes ) ,
504
+ slice_sizes: attr_array_i64_elements ( slice_sizes ) ,
505
505
indices_are_sorted: attr_boolean ( false )
506
506
]
507
507
@@ -546,10 +546,10 @@ defmodule EXLA.MLIR.Value do
546
546
attr_precision_config = attr_precision_config ( precision_config )
547
547
548
548
attributes = [
549
- window_strides: attr_dense_i64_elements ( strides ) ,
549
+ window_strides: attr_array_i64_elements ( strides ) ,
550
550
padding: attr_padding ( padding ) ,
551
- lhs_dilation: attr_dense_i64_elements ( input_dilation ) ,
552
- rhs_dilation: attr_dense_i64_elements ( kernel_dilation ) ,
551
+ lhs_dilation: attr_array_i64_elements ( input_dilation ) ,
552
+ rhs_dilation: attr_array_i64_elements ( kernel_dilation ) ,
553
553
dimension_numbers: attr_conv_dimension_numbers ( dimension_numbers ) ,
554
554
feature_group_count: attr_i64 ( feature_group_count ) ,
555
555
batch_group_count: attr_i64 ( batch_group_count ) ,
@@ -625,7 +625,7 @@ defmodule EXLA.MLIR.Value do
625
625
) do
626
626
operands = inputs ++ init_values
627
627
result_types = typespecs_to_mlir_types ( typespecs )
628
- attributes = [ dimensions: attr_dense_i64_elements ( dimensions ) ]
628
+ attributes = [ dimensions: attr_array_i64_elements ( dimensions ) ]
629
629
regions = [ reducer ]
630
630
op ( func , "stablehlo.reduce" , operands , result_types , attributes: attributes , regions: regions )
631
631
end
@@ -645,10 +645,10 @@ defmodule EXLA.MLIR.Value do
645
645
result_types = typespecs_to_mlir_types ( typespecs )
646
646
647
647
attributes = [
648
- window_dimensions: attr_dense_i64_elements ( window_dimensions ) ,
649
- window_strides: attr_dense_i64_elements ( window_strides ) ,
650
- base_dilations: attr_dense_i64_elements ( input_dilations ) ,
651
- window_dilations: attr_dense_i64_elements ( window_dilations ) ,
648
+ window_dimensions: attr_array_i64_elements ( window_dimensions ) ,
649
+ window_strides: attr_array_i64_elements ( window_strides ) ,
650
+ base_dilations: attr_array_i64_elements ( input_dilations ) ,
651
+ window_dilations: attr_array_i64_elements ( window_dilations ) ,
652
652
padding: attr_padding ( padding )
653
653
]
654
654
@@ -669,7 +669,7 @@ defmodule EXLA.MLIR.Value do
669
669
result_types = typespecs_to_mlir_types ( [ typespec ] )
670
670
671
671
attributes = [
672
- dimensions: attr_dense_i64_elements ( dimensions )
672
+ dimensions: attr_array_i64_elements ( dimensions )
673
673
]
674
674
675
675
regions = [ mapper ]
@@ -904,8 +904,12 @@ defmodule EXLA.MLIR.Value do
904
904
<< value :: size ( size ) - big >>
905
905
end
906
906
907
- defp attr_dense_i64_elements ( list ) do
908
- attr_dense_elements ( list , { :s , 64 } , { length ( list ) } )
907
+ defp attr_array_i64_elements ( [ ] ) do
908
+ "array<i64>"
909
+ end
910
+
911
+ defp attr_array_i64_elements ( list ) do
912
+ "array<i64: #{ Enum . join ( list , ", " ) } >"
909
913
end
910
914
911
915
defp attr_dense_elements ( [ ] , type , { 0 } = shape ) do
0 commit comments