Skip to content

Commit b82de62

Browse files
committed
fix: add a separate check for pad_dot_general
1 parent 9b689ed commit b82de62

File tree

4 files changed: +49 additions, -4 deletions

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9667,12 +9667,15 @@ struct PadDotGeneral
96679667
nextOtherArg = slice.getResult();
96689668
}
96699669

9670+
auto lhs = otherIsLHS ? nextOtherArg : pad.getOperand();
9671+
auto rhs = otherIsLHS ? pad.getOperand() : nextOtherArg;
9672+
if (!areValidDotGeneralInputs(lhs, rhs, op.getDotDimensionNumbersAttr()))
9673+
return failure();
9674+
96709675
Value res = rewriter.create<stablehlo::DotGeneralOp>(
96719676
op.getLoc(),
9672-
RankedTensorType::get(resultShape, op.getType().getElementType()),
9673-
otherIsLHS ? nextOtherArg : pad.getOperand(),
9674-
otherIsLHS ? pad.getOperand() : nextOtherArg,
9675-
op.getDotDimensionNumbersAttr(), op.getPrecisionConfigAttr(),
9677+
RankedTensorType::get(resultShape, op.getType().getElementType()), lhs,
9678+
rhs, op.getDotDimensionNumbersAttr(), op.getPrecisionConfigAttr(),
96769679
op.getAlgorithmAttr());
96779680

96789681
if (!resultDimsToPad.empty()) {

src/enzyme_ad/jax/Utils.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,28 @@ bool reshapeIsTranspose(stablehlo::ReshapeOp reshapeOp) {
12481248
return true;
12491249
}
12501250

1251+
// Returns true when `lhs` and `rhs` have shapes compatible with a
// stablehlo.dot_general carrying `dotDimNumbers`: each paired batching
// dimension and each paired contracting dimension must have equal extents.
//
// NOTE(review): dimension indices are assumed to be in-bounds for the
// operand ranks (true when `dotDimNumbers` comes from a verified op) --
// confirm before calling with an unverified attribute.
bool areValidDotGeneralInputs(Value lhs, Value rhs,
                              stablehlo::DotDimensionNumbersAttr dotDimNumbers) {
  auto lhsShape = cast<ShapedType>(lhs.getType()).getShape();
  auto rhsShape = cast<ShapedType>(rhs.getType()).getShape();

  auto lhsBatch = dotDimNumbers.getLhsBatchingDimensions();
  auto rhsBatch = dotDimNumbers.getRhsBatchingDimensions();
  auto lhsContract = dotDimNumbers.getLhsContractingDimensions();
  auto rhsContract = dotDimNumbers.getRhsContractingDimensions();

  // llvm::zip silently truncates to the shorter range, which would let
  // mismatched dimension-list lengths slip through as "valid". Reject
  // them explicitly before pairwise comparison.
  if (lhsBatch.size() != rhsBatch.size() ||
      lhsContract.size() != rhsContract.size())
    return false;

  for (auto [lhsBatchDim, rhsBatchDim] : llvm::zip(lhsBatch, rhsBatch)) {
    if (lhsShape[lhsBatchDim] != rhsShape[rhsBatchDim])
      return false;
  }

  for (auto [lhsContractingDim, rhsContractingDim] :
       llvm::zip(lhsContract, rhsContract)) {
    if (lhsShape[lhsContractingDim] != rhsShape[rhsContractingDim])
      return false;
  }

  return true;
}
1272+
12511273
} // namespace stablehlo
12521274

12531275
} // namespace mlir

src/enzyme_ad/jax/Utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,9 @@ negatedComparisonDirection(stablehlo::ComparisonDirection direction);
716716

717717
bool reshapeIsTranspose(stablehlo::ReshapeOp reshapeOp);
718718

719+
bool areValidDotGeneralInputs(Value lhs, Value rhs,
720+
stablehlo::DotDimensionNumbersAttr dotDimNumbers);
721+
719722
} // namespace stablehlo
720723

721724
} // namespace mlir

test/lit_tests/paddotgeneral2.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// RUN: enzymexlamlir-opt --pass-pipeline="any(enzyme-hlo-generate-td{patterns=pad_dot_general<1>(1)},transform-interpreter,enzyme-hlo-remove-transform)" %s | FileCheck %s

// Regression test for the pad_dot_general pattern: the dot_general below
// must be left untouched (the CHECK line matches the op verbatim).
// NOTE(review): presumably firing the pattern here would produce a
// dot_general with incompatible operand shapes -- the commit adds
// areValidDotGeneralInputs as a guard for exactly this; confirm against
// the pattern implementation.
func.func @fn(%arg0: tensor<f32>) -> (tensor<32x32xf32>, tensor<f32>) {
  %c = stablehlo.constant dense<3> : tensor<32x32xi64>
  %c_0 = stablehlo.constant dense<2> : tensor<32x32xi64>
  %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
  %0 = stablehlo.reshape %arg0 : (tensor<f32>) -> tensor<1x1xf32>
  // %1 flows into the non-pad operand (%5) of the dot_general; %6 below is
  // the pad operand the pattern targets.
  %1 = stablehlo.pad %0, %cst, low = [2, 3], high = [29, 28], interior = [0, 0] : (tensor<1x1xf32>, tensor<f32>) -> tensor<32x32xf32>
  %2 = stablehlo.convert %c_0 : (tensor<32x32xi64>) -> tensor<32x32xf32>
  %3 = stablehlo.multiply %2, %1 : tensor<32x32xf32>
  %4 = stablehlo.convert %c : (tensor<32x32xi64>) -> tensor<32x32xf32>
  %5 = stablehlo.subtract %3, %4 : tensor<32x32xf32>
  %6 = stablehlo.pad %0, %cst, low = [3, 2], high = [28, 29], interior = [0, 0] : (tensor<1x1xf32>, tensor<f32>) -> tensor<32x32xf32>
  %7 = stablehlo.dot_general %6, %5, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
  // CHECK: %7 = stablehlo.dot_general %6, %5, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
  return %7, %arg0 : tensor<32x32xf32>, tensor<f32>
}

0 commit comments

Comments
 (0)