Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9667,12 +9667,15 @@ struct PadDotGeneral
nextOtherArg = slice.getResult();
}

auto lhs = otherIsLHS ? nextOtherArg : pad.getOperand();
auto rhs = otherIsLHS ? pad.getOperand() : nextOtherArg;
if (!areValidDotGeneralInputs(lhs, rhs, op.getDotDimensionNumbersAttr()))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this feels like the wrong solution. shouldn't we just slice the inputs properly here?

return failure();

Value res = rewriter.create<stablehlo::DotGeneralOp>(
op.getLoc(),
RankedTensorType::get(resultShape, op.getType().getElementType()),
otherIsLHS ? nextOtherArg : pad.getOperand(),
otherIsLHS ? pad.getOperand() : nextOtherArg,
op.getDotDimensionNumbersAttr(), op.getPrecisionConfigAttr(),
RankedTensorType::get(resultShape, op.getType().getElementType()), lhs,
rhs, op.getDotDimensionNumbersAttr(), op.getPrecisionConfigAttr(),
op.getAlgorithmAttr());

if (!resultDimsToPad.empty()) {
Expand Down
22 changes: 22 additions & 0 deletions src/enzyme_ad/jax/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1248,6 +1248,28 @@ bool reshapeIsTranspose(stablehlo::ReshapeOp reshapeOp) {
return true;
}

bool areValidDotGeneralInputs(
Value lhs, Value rhs, stablehlo::DotDimensionNumbersAttr dotDimNumbers) {
auto lhsShape = cast<ShapedType>(lhs.getType()).getShape();
auto rhsShape = cast<ShapedType>(rhs.getType()).getShape();

for (auto [lhsBatchDim, rhsBatchDim] :
llvm::zip(dotDimNumbers.getLhsBatchingDimensions(),
dotDimNumbers.getRhsBatchingDimensions())) {
if (lhsShape[lhsBatchDim] != rhsShape[rhsBatchDim])
return false;
}

for (auto [lhsContractingDim, rhsContractingDim] :
llvm::zip(dotDimNumbers.getLhsContractingDimensions(),
dotDimNumbers.getRhsContractingDimensions())) {
if (lhsShape[lhsContractingDim] != rhsShape[rhsContractingDim])
return false;
}

return true;
}

} // namespace stablehlo

} // namespace mlir
3 changes: 3 additions & 0 deletions src/enzyme_ad/jax/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,9 @@ negatedComparisonDirection(stablehlo::ComparisonDirection direction);

bool reshapeIsTranspose(stablehlo::ReshapeOp reshapeOp);

bool areValidDotGeneralInputs(Value lhs, Value rhs,
stablehlo::DotDimensionNumbersAttr dotDimNumbers);

} // namespace stablehlo

} // namespace mlir
17 changes: 17 additions & 0 deletions test/lit_tests/paddotgeneral2.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: enzymexlamlir-opt --pass-pipeline="any(enzyme-hlo-generate-td{patterns=pad_dot_general<1>(1)},transform-interpreter,enzyme-hlo-remove-transform)" %s | FileCheck %s

func.func @fn(%arg0: tensor<f32>) -> (tensor<32x32xf32>, tensor<f32>) {
%c = stablehlo.constant dense<3> : tensor<32x32xi64>
%c_0 = stablehlo.constant dense<2> : tensor<32x32xi64>
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%0 = stablehlo.reshape %arg0 : (tensor<f32>) -> tensor<1x1xf32>
%1 = stablehlo.pad %0, %cst, low = [2, 3], high = [29, 28], interior = [0, 0] : (tensor<1x1xf32>, tensor<f32>) -> tensor<32x32xf32>
%2 = stablehlo.convert %c_0 : (tensor<32x32xi64>) -> tensor<32x32xf32>
%3 = stablehlo.multiply %2, %1 : tensor<32x32xf32>
%4 = stablehlo.convert %c : (tensor<32x32xi64>) -> tensor<32x32xf32>
%5 = stablehlo.subtract %3, %4 : tensor<32x32xf32>
%6 = stablehlo.pad %0, %cst, low = [3, 2], high = [28, 29], interior = [0, 0] : (tensor<1x1xf32>, tensor<f32>) -> tensor<32x32xf32>
%7 = stablehlo.dot_general %6, %5, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
// CHECK: %7 = stablehlo.dot_general %6, %5, contracting_dims = [0] x [1], precision = [DEFAULT, DEFAULT] : (tensor<32x32xf32>, tensor<32x32xf32>) -> tensor<32x32xf32>
return %7, %arg0 : tensor<32x32xf32>, tensor<f32>
}
Loading