Skip to content

Commit 741829f

Browse files
authored
feat: update MLIR notation to latest stablehlo spec" (#1488)
1 parent e26112d commit 741829f

File tree

2 files changed

+31
-25
lines changed

2 files changed

+31
-25
lines changed

exla/lib/exla/mlir/module.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ defmodule EXLA.MLIR.Module do
9898
do: -1,
9999
else: Keyword.get(options, :device_id, client.default_device_id)
100100

101+
# module.ref |> EXLA.NIF.mlir_module_to_string() |> elem(1) |> IO.puts()
102+
101103
ref =
102104
EXLA.NIF.mlir_compile(
103105
client.ref,

exla/lib/exla/mlir/value.ex

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -186,23 +186,23 @@ defmodule EXLA.MLIR.Value do
186186

187187
def reverse(%Value{function: func} = operand, dims, typespec) do
188188
result_types = typespecs_to_mlir_types([typespec])
189-
attributes = [dimensions: attr_dense_i64_elements(dims)]
189+
attributes = [dimensions: attr_array_i64_elements(dims)]
190190
op(func, "stablehlo.reverse", [operand], result_types, attributes: attributes) |> one!()
191191
end
192192

193193
def transpose(%Value{function: func} = operand, axes, typespec) do
194194
result_types = typespecs_to_mlir_types([typespec])
195-
attributes = [permutation: attr_dense_i64_elements(axes)]
195+
attributes = [permutation: attr_array_i64_elements(axes)]
196196
op(func, "stablehlo.transpose", [operand], result_types, attributes: attributes) |> one!()
197197
end
198198

199199
def slice(%Value{function: func} = operand, starts, limits, strides, typespec) do
200200
result_types = typespecs_to_mlir_types([typespec])
201201

202202
attributes = [
203-
start_indices: attr_dense_i64_elements(starts),
204-
limit_indices: attr_dense_i64_elements(limits),
205-
strides: attr_dense_i64_elements(strides)
203+
start_indices: attr_array_i64_elements(starts),
204+
limit_indices: attr_array_i64_elements(limits),
205+
strides: attr_array_i64_elements(strides)
206206
]
207207

208208
op(func, "stablehlo.slice", [operand], result_types, attributes: attributes) |> one!()
@@ -211,7 +211,7 @@ defmodule EXLA.MLIR.Value do
211211
def dynamic_slice(%Value{function: func} = operand, starts, lengths, typespec) do
212212
result_types = typespecs_to_mlir_types([typespec])
213213
operands = [operand] ++ starts
214-
attributes = [slice_sizes: attr_dense_i64_elements(lengths)]
214+
attributes = [slice_sizes: attr_array_i64_elements(lengths)]
215215
op(func, "stablehlo.dynamic_slice", operands, result_types, attributes: attributes) |> one!()
216216
end
217217

@@ -303,7 +303,7 @@ defmodule EXLA.MLIR.Value do
303303
result_types = typespecs_to_mlir_types([typespec])
304304

305305
attributes = [
306-
broadcast_dimensions: attr_dense_i64_elements(axes)
306+
broadcast_dimensions: attr_array_i64_elements(axes)
307307
]
308308

309309
op(func, "stablehlo.broadcast_in_dim", [operand], result_types, attributes: attributes)
@@ -347,9 +347,9 @@ defmodule EXLA.MLIR.Value do
347347
{padding_low, padding_high, padding_mid} = unzip_padding_config(padding_config)
348348

349349
attributes = [
350-
edge_padding_low: attr_dense_i64_elements(padding_low),
351-
edge_padding_high: attr_dense_i64_elements(padding_high),
352-
interior_padding: attr_dense_i64_elements(padding_mid)
350+
edge_padding_low: attr_array_i64_elements(padding_low),
351+
edge_padding_high: attr_array_i64_elements(padding_high),
352+
interior_padding: attr_array_i64_elements(padding_mid)
353353
]
354354

355355
op(func, "stablehlo.pad", [operand, pad], result_types, attributes: attributes) |> one!()
@@ -375,7 +375,7 @@ defmodule EXLA.MLIR.Value do
375375

376376
attributes = [
377377
fft_type: fft_type,
378-
fft_length: attr_dense_i64_elements(List.wrap(fft_length))
378+
fft_length: attr_array_i64_elements(List.wrap(fft_length))
379379
]
380380

381381
op(func, "stablehlo.fft", [value], result_types, attributes: attributes) |> one!()
@@ -451,8 +451,8 @@ defmodule EXLA.MLIR.Value do
451451
result_types = typespecs_to_mlir_types([typespec])
452452

453453
attributes = [
454-
window_dimensions: attr_dense_i64_elements(window_dimensions),
455-
window_strides: attr_dense_i64_elements(window_strides),
454+
window_dimensions: attr_array_i64_elements(window_dimensions),
455+
window_strides: attr_array_i64_elements(window_strides),
456456
padding: attr_padding(padding)
457457
]
458458

@@ -501,7 +501,7 @@ defmodule EXLA.MLIR.Value do
501501

502502
attributes = [
503503
dimension_numbers: dimension_numbers,
504-
slice_sizes: attr_dense_i64_elements(slice_sizes),
504+
slice_sizes: attr_array_i64_elements(slice_sizes),
505505
indices_are_sorted: attr_boolean(false)
506506
]
507507

@@ -546,10 +546,10 @@ defmodule EXLA.MLIR.Value do
546546
attr_precision_config = attr_precision_config(precision_config)
547547

548548
attributes = [
549-
window_strides: attr_dense_i64_elements(strides),
549+
window_strides: attr_array_i64_elements(strides),
550550
padding: attr_padding(padding),
551-
lhs_dilation: attr_dense_i64_elements(input_dilation),
552-
rhs_dilation: attr_dense_i64_elements(kernel_dilation),
551+
lhs_dilation: attr_array_i64_elements(input_dilation),
552+
rhs_dilation: attr_array_i64_elements(kernel_dilation),
553553
dimension_numbers: attr_conv_dimension_numbers(dimension_numbers),
554554
feature_group_count: attr_i64(feature_group_count),
555555
batch_group_count: attr_i64(batch_group_count),
@@ -625,7 +625,7 @@ defmodule EXLA.MLIR.Value do
625625
) do
626626
operands = inputs ++ init_values
627627
result_types = typespecs_to_mlir_types(typespecs)
628-
attributes = [dimensions: attr_dense_i64_elements(dimensions)]
628+
attributes = [dimensions: attr_array_i64_elements(dimensions)]
629629
regions = [reducer]
630630
op(func, "stablehlo.reduce", operands, result_types, attributes: attributes, regions: regions)
631631
end
@@ -645,10 +645,10 @@ defmodule EXLA.MLIR.Value do
645645
result_types = typespecs_to_mlir_types(typespecs)
646646

647647
attributes = [
648-
window_dimensions: attr_dense_i64_elements(window_dimensions),
649-
window_strides: attr_dense_i64_elements(window_strides),
650-
base_dilations: attr_dense_i64_elements(input_dilations),
651-
window_dilations: attr_dense_i64_elements(window_dilations),
648+
window_dimensions: attr_array_i64_elements(window_dimensions),
649+
window_strides: attr_array_i64_elements(window_strides),
650+
base_dilations: attr_array_i64_elements(input_dilations),
651+
window_dilations: attr_array_i64_elements(window_dilations),
652652
padding: attr_padding(padding)
653653
]
654654

@@ -669,7 +669,7 @@ defmodule EXLA.MLIR.Value do
669669
result_types = typespecs_to_mlir_types([typespec])
670670

671671
attributes = [
672-
dimensions: attr_dense_i64_elements(dimensions)
672+
dimensions: attr_array_i64_elements(dimensions)
673673
]
674674

675675
regions = [mapper]
@@ -904,8 +904,12 @@ defmodule EXLA.MLIR.Value do
904904
<<value::size(size)-big>>
905905
end
906906

907-
defp attr_dense_i64_elements(list) do
908-
attr_dense_elements(list, {:s, 64}, {length(list)})
907+
defp attr_array_i64_elements([]) do
908+
"array<i64>"
909+
end
910+
911+
defp attr_array_i64_elements(list) do
912+
"array<i64: #{Enum.join(list, ", ")}>"
909913
end
910914

911915
defp attr_dense_elements([], type, {0} = shape) do

0 commit comments

Comments
 (0)