Skip to content

Commit b3da89c

Browse files
authored
fix (#2032)
1 parent 2685ba6 commit b3da89c

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26107,6 +26107,9 @@ struct ConcatenateBroadcastInDim
2610726107
return failure();
2610826108
if (bcastInDimDimensions[input_dim] != op.getDimension())
2610926109
return failure();
26110+
if (broadcastInDimOp.getOperand().getType().getShape()[input_dim] !=
26111+
broadcastInDimOp.getType().getShape()[op.getDimension()])
26112+
return failure();
2611026113
operandOperands.push_back(broadcastInDimOp.getOperand());
2611126114
continue;
2611226115
}

test/lit_tests/concatenate_bcastindim.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,17 @@ func.func @main2(%arg0: tensor<2xi1>, %arg1: tensor<2xi1>) -> (tensor<3x4xi1>) {
2828
// CHECK-NEXT: %1 = stablehlo.broadcast_in_dim %0, dims = [1] : (tensor<4xi1>) -> tensor<3x4xi1>
2929
// CHECK-NEXT: return %1 : tensor<3x4xi1>
3030
// CHECK-NEXT: }
31+
32+
// @main3: both operands are expanded by broadcast_in_dim from extent 1 to
// extent 4 along dimension 0, and the results are concatenated along that same
// dimension. The expected output for this function keeps both broadcasts and
// the concatenate exactly as written, i.e. no rewrite is expected here.
func.func @main3(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<8xf32> {
33+
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0] : (tensor<1xf32>) -> tensor<4xf32>
34+
%1 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<1xf32>) -> tensor<4xf32>
35+
%2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
36+
return %2 : tensor<8xf32>
37+
}
38+
39+
// CHECK: func.func @main3(%[[ARG0:.+]]: tensor<1xf32>, %[[ARG1:.+]]: tensor<1xf32>) -> tensor<8xf32> {
40+
// CHECK-NEXT: %[[B0:.+]] = stablehlo.broadcast_in_dim %[[ARG0]], dims = [0] : (tensor<1xf32>) -> tensor<4xf32>
41+
// CHECK-NEXT: %[[B1:.+]] = stablehlo.broadcast_in_dim %[[ARG1]], dims = [0] : (tensor<1xf32>) -> tensor<4xf32>
42+
// CHECK-NEXT: %[[CAT:.+]] = stablehlo.concatenate %[[B0]], %[[B1]], dim = 0 : (tensor<4xf32>, tensor<4xf32>) -> tensor<8xf32>
43+
// CHECK-NEXT: return %[[CAT]] : tensor<8xf32>
44+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)