Skip to content

Commit 1936050

Browse files
committed
insert right after operand
1 parent e17821b commit 1936050

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30395,6 +30395,7 @@ struct RecognizeMultiRotate
3039530395
int32_t totalResults = leftAmount + rightAmount + 1;
3039630396

3039730397
// Create the MultiRotateOp
30398+
rewriter.setInsertionPointAfterValue(input);
3039830399
auto newOp = rewriter.create<enzymexla::MultiRotateOp>(
3039930400
op.getLoc(),
3040030401
SmallVector<Type>(totalResults, input.getType()),

test/lit_tests/recognize_multirotate.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,8 @@ func.func @partial_combination(%arg0: tensor<10x20xf32>) -> (tensor<10x20xf32>,
149149

150150
// CHECK-LABEL: func.func @partial_combination(
151151
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: tensor<10x20xf32>) -> (tensor<10x20xf32>, tensor<10x20xf32>, tensor<10x20xf32>) {
152-
// CHECK: %[[VAL_1:.*]] = "enzymexla.rotate"(%[[VAL_0]]) <{amount = -2 : si32, dimension = 0 : si32}> : (tensor<10x20xf32>) -> tensor<10x20xf32>
153152
// CHECK: %[[VAL_2:.*]]:3 = "enzymexla.multi_rotate"(%[[VAL_0]]) <{dimension = 0 : si32, left_amount = 2 : si32, right_amount = 0 : si32}> : (tensor<10x20xf32>) -> (tensor<10x20xf32>, tensor<10x20xf32>, tensor<10x20xf32>)
153+
// CHECK: %[[VAL_1:.*]] = "enzymexla.rotate"(%[[VAL_0]]) <{amount = -2 : si32, dimension = 0 : si32}> : (tensor<10x20xf32>) -> tensor<10x20xf32>
154154
// CHECK: return %[[VAL_1]], %[[VAL_2]]#1, %[[VAL_2]]#0 : tensor<10x20xf32>, tensor<10x20xf32>, tensor<10x20xf32>
155155
// CHECK: }
156156

0 commit comments

Comments
 (0)