Skip to content

Commit caa934b

Browse files
committed
generalize enzyme.dot
1 parent f2eb242 commit caa934b

File tree

2 files changed

+122
-124
lines changed

2 files changed

+122
-124
lines changed

src/enzyme_ad/jax/Passes/LowerEnzymeProbProg.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,14 +1590,18 @@ struct DotOpConversion : public OpConversionPattern<enzyme::DotOp> {
15901590
auto lhs = adaptor.getLhs();
15911591
auto rhs = adaptor.getRhs();
15921592
auto resultType = cast<RankedTensorType>(op.getResult().getType());
1593-
auto lhsType = cast<RankedTensorType>(lhs.getType());
1593+
1594+
auto lhsBatching = op.getLhsBatchingDimensions();
1595+
auto rhsBatching = op.getRhsBatchingDimensions();
1596+
auto lhsContracting = op.getLhsContractingDimensions();
1597+
auto rhsContracting = op.getRhsContractingDimensions();
15941598

15951599
auto dotDimensionNumbers = stablehlo::DotDimensionNumbersAttr::get(
15961600
rewriter.getContext(),
1597-
/*lhs_batching_dimensions=*/{},
1598-
/*rhs_batching_dimensions=*/{},
1599-
/*lhs_contracting_dimensions=*/{0},
1600-
/*rhs_contracting_dimensions=*/{0});
1601+
SmallVector<int64_t>(lhsBatching.begin(), lhsBatching.end()),
1602+
SmallVector<int64_t>(rhsBatching.begin(), rhsBatching.end()),
1603+
SmallVector<int64_t>(lhsContracting.begin(), lhsContracting.end()),
1604+
SmallVector<int64_t>(rhsContracting.begin(), rhsContracting.end()));
16011605

16021606
auto dotOp = stablehlo::DotGeneralOp::create(
16031607
rewriter, op.getLoc(), resultType, lhs, rhs, dotDimensionNumbers,

0 commit comments

Comments
 (0)