diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index e7a1d1cd1..9e111fb1d 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -9667,12 +9667,15 @@ struct PadDotGeneral
       nextOtherArg = slice.getResult();
     }
 
+    auto lhs = otherIsLHS ? nextOtherArg : pad.getOperand();
+    auto rhs = otherIsLHS ? pad.getOperand() : nextOtherArg;
+    if (!areValidDotGeneralInputs(lhs, rhs, op.getDotDimensionNumbersAttr()))
+      return failure();
+
     Value res = rewriter.create<stablehlo::DotGeneralOp>(
         op.getLoc(),
-        RankedTensorType::get(resultShape, op.getType().getElementType()),
-        otherIsLHS ? nextOtherArg : pad.getOperand(),
-        otherIsLHS ? pad.getOperand() : nextOtherArg,
-        op.getDotDimensionNumbersAttr(), op.getPrecisionConfigAttr(),
+        RankedTensorType::get(resultShape, op.getType().getElementType()), lhs,
+        rhs, op.getDotDimensionNumbersAttr(), op.getPrecisionConfigAttr(),
         op.getAlgorithmAttr());
 
     if (!resultDimsToPad.empty()) {
diff --git a/src/enzyme_ad/jax/Utils.cpp b/src/enzyme_ad/jax/Utils.cpp
index 8c8136f45..ac718a863 100644
--- a/src/enzyme_ad/jax/Utils.cpp
+++ b/src/enzyme_ad/jax/Utils.cpp
@@ -1248,6 +1248,28 @@ bool reshapeIsTranspose(stablehlo::ReshapeOp reshapeOp) {
   return true;
 }
 
+bool areValidDotGeneralInputs(
+    Value lhs, Value rhs, stablehlo::DotDimensionNumbersAttr dotDimNumbers) {
+  auto lhsShape = cast<RankedTensorType>(lhs.getType()).getShape();
+  auto rhsShape = cast<RankedTensorType>(rhs.getType()).getShape();
+
+  for (auto [lhsBatchDim, rhsBatchDim] :
+       llvm::zip(dotDimNumbers.getLhsBatchingDimensions(),
+                 dotDimNumbers.getRhsBatchingDimensions())) {
+    if (lhsShape[lhsBatchDim] != rhsShape[rhsBatchDim])
+      return false;
+  }
+
+  for (auto [lhsContractingDim, rhsContractingDim] :
+       llvm::zip(dotDimNumbers.getLhsContractingDimensions(),
+                 dotDimNumbers.getRhsContractingDimensions())) {
+    if (lhsShape[lhsContractingDim] != rhsShape[rhsContractingDim])
+      return false;
+  }
+
+  return true;
+}
+
 } // namespace stablehlo
 } // namespace mlir
diff --git a/src/enzyme_ad/jax/Utils.h b/src/enzyme_ad/jax/Utils.h
index c6a05db26..b7667f79d 100644
--- a/src/enzyme_ad/jax/Utils.h
+++ b/src/enzyme_ad/jax/Utils.h
@@ -716,6 +716,9 @@
 negatedComparisonDirection(stablehlo::ComparisonDirection direction);
 
 bool reshapeIsTranspose(stablehlo::ReshapeOp reshapeOp);
 
+bool areValidDotGeneralInputs(Value lhs, Value rhs,
+                              stablehlo::DotDimensionNumbersAttr dotDimNumbers);
+
 } // namespace stablehlo
 } // namespace mlir
diff --git a/test/lit_tests/paddotgeneral2.mlir b/test/lit_tests/paddotgeneral2.mlir
new file mode 100644
index 000000000..17b385c90
--- /dev/null
+++ b/test/lit_tests/paddotgeneral2.mlir
@@ -0,0 +1,17 @@
+// RUN: enzymexlamlir-opt --pass-pipeline="any(enzyme-hlo-generate-td{patterns=pad_dot_general<1>(1)},transform-interpreter,enzyme-hlo-remove-transform)" %s | FileCheck %s
+
+func.func @fn(%arg0: tensor<f32>) -> (tensor<32x32xf32>, tensor<f32>) {
+  %c = stablehlo.constant dense<3> : tensor<32x32xi64>
+  %c_0 = stablehlo.constant dense<2> : tensor<32x32xi64>
+  %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
+  %0 = stablehlo.reshape %arg0 : (tensor<f32>) -> tensor<1x1xf32>
+  %1 = stablehlo.pad %0, %cst, low = [2, 3], high = [29, 28], interior = [0, 0] : (tensor<1x1xf32>, tensor<f32>) -> tensor<32x32xf32>
+  %2 = stablehlo.convert %c_0 : (tensor<32x32xi64>) -> tensor<32x32xf32>
+  %3 = stablehlo.multiply %2, %1 : tensor<32x32xf32>
+  %4 = stablehlo.convert %c : (tensor<32x32xi64>) -> tensor<32x32xf32>
+  %5 = stablehlo.subtract %3, %4 : tensor<32x32xf32>
+  %6 = stablehlo.pad %0, %cst, low = [3, 2], high = [28, 29], interior = [0, 0] : (tensor<1x1xf32>, tensor<f32>) -> tensor<32x32xf32>
+  %7 = stablehlo.dot_general %6, %5, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
+  // CHECK: %7 = stablehlo.dot_general %6, %5, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
+  return %7, %arg0 : tensor<32x32xf32>, tensor<f32>
+}