We can call it some other name too instead of reusing `hlo_call`. The idea is to give it an existing MLIR string and have it do the following:
```julia
Reactant.Ops.<hlo_call>(
    """
    module @reactant_diagmm attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
      func.func @main(%arg0: tensor<1024xf32> {enzymexla.memory_effects = []}, %arg1: tensor<32x1024xf32> {enzymexla.memory_effects = []}, %arg2: tensor<32x1024xf32> {enzymexla.memory_effects = []}) -> tensor<f32> attributes {enzymexla.memory_effects = []} {
        %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
        %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<1024x1024xf32>
        %0 = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<32x1024xf32>) -> tensor<1024x32xf32>
        %1 = stablehlo.iota dim = 0 : tensor<1024x2xi64>
        %2 = "stablehlo.scatter"(%cst_0, %1, %arg0) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
        ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
          stablehlo.return %arg4 : tensor<f32>
        }) : (tensor<1024x1024xf32>, tensor<1024x2xi64>, tensor<1024xf32>) -> tensor<1024x1024xf32>
        %3 = stablehlo.dot_general %2, %arg1, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<1024x1024xf32>, tensor<32x1024xf32>) -> tensor<1024x32xf32>
        %4 = stablehlo.add %3, %0 : tensor<1024x32xf32>
        %5 = stablehlo.reduce(%4 init: %cst) applies stablehlo.add across dimensions = [0, 1] : (tensor<1024x32xf32>, tensor<f32>) -> tensor<f32>
        return %5 : tensor<f32>
      }
    }
    """
)
```

This roughly gets translated into:
```mlir
func.func @input_hlo(....) {
  ....
}

func.func @main() {
  %inp0 = stablehlo.constant ....
  %inp1 = stablehlo.constant ....
  %inp2 = stablehlo.constant ....
  // We need the opt barrier to prevent const folding
  %inp00, %inp10, %inp20 = stablehlo.optimization_barrier(%inp0, %inp1, %inp2)
  %res = func.call @input_hlo(%inp00, %inp10, %inp20)
  return %res
}
```

This would make testing optimization passes on the input MLIR much more convenient than having to feed in the inputs manually.
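
For illustration, here is a minimal sketch of how the wrapper `@main` could be synthesized from the entry function's signature. Everything here is hypothetical: the helper name `wrap_with_constant_inputs`, the splat constant initializer, and `input_hlo` as the renamed entry point are assumptions for this issue, not existing Reactant API.

```julia
# Hypothetical sketch (not Reactant's implementation): build the wrapper
# module text by plain string interpolation. Assumes float element types
# so a dense splat constant is a valid initializer for every argument.
function wrap_with_constant_inputs(callee::AbstractString,
                                   arg_types::Vector{String},
                                   ret_type::AbstractString)
    # One constant per argument of the wrapped function.
    consts = join(["  %inp$(i) = stablehlo.constant dense<1.000000e+00> : $(t)"
                   for (i, t) in enumerate(arg_types)], "\n")
    ins  = join(["%inp$(i)" for i in eachindex(arg_types)], ", ")
    outs = join(["%b$(i)" for i in eachindex(arg_types)], ", ")
    tys  = join(arg_types, ", ")
    return """
    func.func @main() -> $(ret_type) {
    $(consts)
      // The barrier keeps the constants from being folded into the callee.
      $(outs) = "stablehlo.optimization_barrier"($(ins)) : ($(tys)) -> ($(tys))
      %res = func.call @$(callee)($(outs)) : ($(tys)) -> $(ret_type)
      return %res : $(ret_type)
    }
    """
end

# For the module above: three arguments, scalar f32 result.
print(wrap_with_constant_inputs("input_hlo",
    ["tensor<1024xf32>", "tensor<32x1024xf32>", "tensor<32x1024xf32>"],
    "tensor<f32>"))
```

Keeping the callee in its own `func.func` and routing the constants through `stablehlo.optimization_barrier` means the pass pipeline sees the original computation with statically known input shapes, while constant folding can't collapse it away.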