Skip to content

Commit 681189b

Browse files
committed
generalize enzyme.dot
1 parent c3cafd5 commit 681189b

File tree

2 files changed

+123
-124
lines changed

2 files changed

+123
-124
lines changed

src/enzyme_ad/jax/Passes/LowerEnzymeProbProg.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1558,14 +1558,18 @@ struct DotOpConversion : public OpConversionPattern<enzyme::DotOp> {
15581558
auto lhs = adaptor.getLhs();
15591559
auto rhs = adaptor.getRhs();
15601560
auto resultType = cast<RankedTensorType>(op.getResult().getType());
1561-
auto lhsType = cast<RankedTensorType>(lhs.getType());
1561+
1562+
auto lhsBatching = op.getLhsBatchingDimensions();
1563+
auto rhsBatching = op.getRhsBatchingDimensions();
1564+
auto lhsContracting = op.getLhsContractingDimensions();
1565+
auto rhsContracting = op.getRhsContractingDimensions();
15621566

15631567
auto dotDimensionNumbers = stablehlo::DotDimensionNumbersAttr::get(
15641568
rewriter.getContext(),
1565-
/*lhs_batching_dimensions=*/{},
1566-
/*rhs_batching_dimensions=*/{},
1567-
/*lhs_contracting_dimensions=*/{0},
1568-
/*rhs_contracting_dimensions=*/{0});
1569+
SmallVector<int64_t>(lhsBatching.begin(), lhsBatching.end()),
1570+
SmallVector<int64_t>(rhsBatching.begin(), rhsBatching.end()),
1571+
SmallVector<int64_t>(lhsContracting.begin(), lhsContracting.end()),
1572+
SmallVector<int64_t>(rhsContracting.begin(), rhsContracting.end()));
15691573

15701574
auto dotOp = rewriter.create<stablehlo::DotGeneralOp>(
15711575
op.getLoc(), resultType, lhs, rhs, dotDimensionNumbers,

0 commit comments

Comments
 (0)