@@ -1558,14 +1558,18 @@ struct DotOpConversion : public OpConversionPattern<enzyme::DotOp> {
15581558 auto lhs = adaptor.getLhs ();
15591559 auto rhs = adaptor.getRhs ();
15601560 auto resultType = cast<RankedTensorType>(op.getResult ().getType ());
1561- auto lhsType = cast<RankedTensorType>(lhs.getType ());
1561+
1562+ auto lhsBatching = op.getLhsBatchingDimensions ();
1563+ auto rhsBatching = op.getRhsBatchingDimensions ();
1564+ auto lhsContracting = op.getLhsContractingDimensions ();
1565+ auto rhsContracting = op.getRhsContractingDimensions ();
15621566
15631567 auto dotDimensionNumbers = stablehlo::DotDimensionNumbersAttr::get (
15641568 rewriter.getContext (),
1565- /* lhs_batching_dimensions= */ {} ,
1566- /* rhs_batching_dimensions= */ {} ,
1567- /* lhs_contracting_dimensions= */ { 0 } ,
1568- /* rhs_contracting_dimensions= */ { 0 } );
1569+ SmallVector< int64_t >(lhsBatching. begin (), lhsBatching. end ()) ,
1570+ SmallVector< int64_t >(rhsBatching. begin (), rhsBatching. end ()) ,
1571+ SmallVector< int64_t >(lhsContracting. begin (), lhsContracting. end ()) ,
1572+ SmallVector< int64_t >(rhsContracting. begin (), rhsContracting. end ()) );
15691573
15701574 auto dotOp = rewriter.create <stablehlo::DotGeneralOp>(
15711575 op.getLoc (), resultType, lhs, rhs, dotDimensionNumbers,
0 commit comments