Skip to content

Commit 7872c77

Browse files
authored
feat: support convert in unary scatter (#1039)
1 parent 9fda595 commit 7872c77

File tree

2 files changed

+57
-10
lines changed

2 files changed

+57
-10
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19642,12 +19642,6 @@ struct UnaryElementwiseScatterSimplify final
     auto elemType =
         cast<RankedTensorType>(op->getResult(0).getType()).getElementType();

-    // TODO: support convert op. we need to rewrite the update computation to
-    // take the converted element type
-    if (isa<stablehlo::ConvertOp>(op))
-      return rewriter.notifyMatchFailure(op,
-                                         "ConvertOp not supported for now.");
-
     // should get constant propagated
     auto scatterInputElem = rewriter.create(
         op->getLoc(), op->getName().getIdentifier(), ValueRange(scatterInput),
@@ -19676,8 +19670,15 @@ struct UnaryElementwiseScatterSimplify final
         ValueRange(scatterUpdatesElem->getResult(0)),
         scatterOp.getScatterDimensionNumbersAttr(),
         scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr());
-    newScatterOp.getUpdateComputation().takeBody(
-        scatterOp.getUpdateComputation());
+
+    auto &updateRegion = newScatterOp.getUpdateComputation();
+    auto *block = rewriter.createBlock(&updateRegion);
+    auto argType = RankedTensorType::get({}, elemType);
+    block->addArgument(argType, op->getLoc());
+    block->addArgument(argType, op->getLoc());
+    rewriter.setInsertionPointToStart(block);
+    rewriter.create<stablehlo::ReturnOp>(op->getLoc(), block->getArgument(1));
+
     rewriter.replaceOp(op, newScatterOp->getResult(0));
     return success();
   }

test/lit_tests/unaryscatter.mlir

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func.func @unaryscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
 // CHECK-NEXT:    return %7 : tensor<1024x1024xf32>
 // CHECK-NEXT: }

-func.func @convertscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
+func.func @expscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
   %cst = stablehlo.constant dense<2.000000e+00> : tensor<24xf32>
   %c = stablehlo.constant dense<1> : tensor<24x2xi64>
   %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<1024x1024xf32>
@@ -56,7 +56,7 @@ func.func @convertscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
   return %8 : tensor<1024x1024xf32>
 }

-// CHECK: func.func @convertscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
+// CHECK: func.func @expscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
 // CHECK-NEXT:    %cst = stablehlo.constant dense<7.3890562> : tensor<24xf32>
 // CHECK-NEXT:    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1024x1024xf32>
 // CHECK-NEXT:    %c = stablehlo.constant dense<1> : tensor<24x2xi64>
@@ -73,3 +73,49 @@ func.func @convertscatter(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
 // CHECK-NEXT:    %7 = stablehlo.transpose %6, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
 // CHECK-NEXT:    return %7 : tensor<1024x1024xf32>
 // CHECK-NEXT: }
+
+func.func @convertscatter(%arg0: tensor<5x4xf32>, %arg1: tensor<5xui32>) -> tensor<5x4xf32> {
+  %c = stablehlo.constant dense<[[4, 5], [4, 5], [4, 5], [4, 5], [4, 5]]> : tensor<5x2xi64>
+  %c_0 = stablehlo.constant dense<[-1, 3, 7, 11, 15]> : tensor<5xi64>
+  %c_1 = stablehlo.constant dense<true> : tensor<5xi1>
+  %c_2 = stablehlo.constant dense<4> : tensor<5xi64>
+  %c_3 = stablehlo.constant dense<false> : tensor<4x5xi1>
+  %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x4xf32>) -> tensor<4x5xf32>
+  %1 = stablehlo.convert %arg1 : (tensor<5xui32>) -> tensor<5xi64>
+  %2 = stablehlo.add %1, %c_0 : tensor<5xi64>
+  %3 = stablehlo.divide %2, %c_2 : tensor<5xi64>
+  %4 = stablehlo.reshape %2 : (tensor<5xi64>) -> tensor<5x1xi64>
+  %5 = stablehlo.reshape %3 : (tensor<5xi64>) -> tensor<5x1xi64>
+  %6 = stablehlo.concatenate %4, %5, dim = 1 : (tensor<5x1xi64>, tensor<5x1xi64>) -> tensor<5x2xi64>
+  %7 = stablehlo.remainder %6, %c : tensor<5x2xi64>
+  %8 = "stablehlo.scatter"(%c_3, %7, %c_1) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
+  ^bb0(%arg2: tensor<i1>, %arg3: tensor<i1>):
+    stablehlo.return %arg3 : tensor<i1>
+  }) : (tensor<4x5xi1>, tensor<5x2xi64>, tensor<5xi1>) -> tensor<4x5xi1>
+  %9 = stablehlo.convert %8 : (tensor<4x5xi1>) -> tensor<4x5xf32>
+  %10 = stablehlo.multiply %0, %9 : tensor<4x5xf32>
+  %11 = stablehlo.transpose %10, dims = [1, 0] : (tensor<4x5xf32>) -> tensor<5x4xf32>
+  return %11 : tensor<5x4xf32>
+}
+
+// CHECK: func.func @convertscatter(%arg0: tensor<5x4xf32>, %arg1: tensor<5xui32>) -> tensor<5x4xf32> {
+// CHECK-NEXT:    %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x5xf32>
+// CHECK-NEXT{LITERAL}:    %c = stablehlo.constant dense<[[4, 5], [4, 5], [4, 5], [4, 5], [4, 5]]> : tensor<5x2xi64>
+// CHECK-NEXT:    %c_0 = stablehlo.constant dense<[-1, 3, 7, 11, 15]> : tensor<5xi64>
+// CHECK-NEXT:    %c_1 = stablehlo.constant dense<4> : tensor<5xi64>
+// CHECK-NEXT:    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<5x4xf32>) -> tensor<4x5xf32>
+// CHECK-NEXT:    %1 = stablehlo.convert %arg1 : (tensor<5xui32>) -> tensor<5xi64>
+// CHECK-NEXT:    %2 = stablehlo.add %1, %c_0 : tensor<5xi64>
+// CHECK-NEXT:    %3 = stablehlo.divide %2, %c_1 : tensor<5xi64>
+// CHECK-NEXT:    %4 = stablehlo.reshape %2 : (tensor<5xi64>) -> tensor<5x1xi64>
+// CHECK-NEXT:    %5 = stablehlo.reshape %3 : (tensor<5xi64>) -> tensor<5x1xi64>
+// CHECK-NEXT:    %6 = stablehlo.concatenate %4, %5, dim = 1 : (tensor<5x1xi64>, tensor<5x1xi64>) -> tensor<5x2xi64>
+// CHECK-NEXT:    %7 = stablehlo.remainder %6, %c : tensor<5x2xi64>
+// CHECK-NEXT:    %8 = "stablehlo.gather"(%0, %7) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, slice_sizes = array<i64: 1, 1>}> : (tensor<4x5xf32>, tensor<5x2xi64>) -> tensor<5xf32>
+// CHECK-NEXT:    %9 = "stablehlo.scatter"(%cst, %7, %8) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>}> ({
+// CHECK-NEXT:    ^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
+// CHECK-NEXT:      stablehlo.return %arg3 : tensor<f32>
+// CHECK-NEXT:    }) : (tensor<4x5xf32>, tensor<5x2xi64>, tensor<5xf32>) -> tensor<4x5xf32>
+// CHECK-NEXT:    %10 = stablehlo.transpose %9, dims = [1, 0] : (tensor<4x5xf32>) -> tensor<5x4xf32>
+// CHECK-NEXT:    return %10 : tensor<5x4xf32>
+// CHECK-NEXT: }

0 commit comments

Comments (0)