@@ -1590,14 +1590,18 @@ struct DotOpConversion : public OpConversionPattern<enzyme::DotOp> {
15901590 auto lhs = adaptor.getLhs ();
15911591 auto rhs = adaptor.getRhs ();
15921592 auto resultType = cast<RankedTensorType>(op.getResult ().getType ());
1593- auto lhsType = cast<RankedTensorType>(lhs.getType ());
1593+
1594+ auto lhsBatching = op.getLhsBatchingDimensions ();
1595+ auto rhsBatching = op.getRhsBatchingDimensions ();
1596+ auto lhsContracting = op.getLhsContractingDimensions ();
1597+ auto rhsContracting = op.getRhsContractingDimensions ();
15941598
15951599 auto dotDimensionNumbers = stablehlo::DotDimensionNumbersAttr::get (
15961600 rewriter.getContext (),
1597- /* lhs_batching_dimensions= */ {} ,
1598- /* rhs_batching_dimensions= */ {} ,
1599- /* lhs_contracting_dimensions= */ { 0 } ,
1600- /* rhs_contracting_dimensions= */ { 0 } );
1601+ SmallVector< int64_t >(lhsBatching. begin (), lhsBatching. end ()) ,
1602+ SmallVector< int64_t >(rhsBatching. begin (), rhsBatching. end ()) ,
1603+ SmallVector< int64_t >(lhsContracting. begin (), lhsContracting. end ()) ,
1604+ SmallVector< int64_t >(rhsContracting. begin (), rhsContracting. end ()) );
16011605
16021606 auto dotOp = stablehlo::DotGeneralOp::create (
16031607 rewriter, op.getLoc (), resultType, lhs, rhs, dotDimensionNumbers,
0 commit comments