@@ -90,43 +90,58 @@ static Value createLinalgBodyCalculationForElementwiseOp(
9090 }
9191
9292 // tosa::MulOp
93- if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
94- return rewriter.create <arith::MulFOp>(loc, resultTypes, args);
95-
96- if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
97- Value a = args[0 ];
98- Value b = args[1 ];
99- auto shift =
100- cast<IntegerAttr>(op->getAttr (" shift" )).getValue ().getSExtValue ();
101- if (shift > 0 ) {
102- auto shiftConst =
103- rewriter.create <arith::ConstantIntOp>(loc, shift, /* bitwidth=*/ 8 );
104- if (!a.getType ().isInteger (32 ))
105- a = rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), a);
106-
107- if (!b.getType ().isInteger (32 ))
108- b = rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), b);
109-
110- auto result = rewriter.create <tosa::ApplyScaleOp>(
111- loc, rewriter.getI32Type (), a, b, shiftConst,
112- rewriter.getBoolAttr (false ));
113-
114- if (elementTy.isInteger (32 ))
115- return result;
116-
117- return rewriter.create <arith::TruncIOp>(loc, elementTy, result);
93+ if (isa<tosa::MulOp>(op)) {
94+ auto shift_val = cast<tosa::MulOp>(op).getShift ();
95+ if (!elementTy.isInteger (32 ) && shift_val.getImpl ()) {
96+ (void )rewriter.notifyMatchFailure (op,
97+ " Cannot have shift value for non i32 output" );
98+ return nullptr ;
99+ };
100+
101+ if (isa<FloatType>(elementTy)) {
102+ return rewriter.create <arith::MulFOp>(loc, resultTypes, args[0 ], args[1 ]);
118103 }
119104
120- int aWidth = a.getType ().getIntOrFloatBitWidth ();
121- int bWidth = b.getType ().getIntOrFloatBitWidth ();
122- int cWidth = resultTypes[0 ].getIntOrFloatBitWidth ();
105+ if (isa<IntegerType>(elementTy)) {
106+ int32_t shift = 0 ;
107+ ElementsAttr shift_elem;
108+ if (shift_val.getImpl () && matchPattern (shift_val, m_Constant (&shift_elem))) {
109+ // Explicit shift is set.
110+ shift = shift_elem.getValues <IntegerAttr>()[0 ].getInt ();
111+ }
112+
113+ Value a = args[0 ];
114+ Value b = args[1 ];
115+ if (shift > 0 ) {
116+ auto shiftConst =
117+ rewriter.create <arith::ConstantIntOp>(loc, shift, /* bitwidth=*/ 8 );
118+ if (!a.getType ().isInteger (32 ))
119+ a = rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), a);
120+
121+ if (!b.getType ().isInteger (32 ))
122+ b = rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), b);
123+
124+ auto result = rewriter.create <tosa::ApplyScaleOp>(
125+ loc, rewriter.getI32Type (), a, b, shiftConst,
126+ rewriter.getBoolAttr (false ));
123127
124- if (aWidth < cWidth)
125- a = rewriter.create <arith::ExtSIOp>(loc, resultTypes[0 ], a);
126- if (bWidth < cWidth)
127- b = rewriter.create <arith::ExtSIOp>(loc, resultTypes[0 ], b);
128+ if (elementTy.isInteger (32 ))
129+ return result;
128130
129- return rewriter.create <arith::MulIOp>(loc, resultTypes, a, b);
131+ return rewriter.create <arith::TruncIOp>(loc, elementTy, result);
132+ }
133+
134+ int aWidth = a.getType ().getIntOrFloatBitWidth ();
135+ int bWidth = b.getType ().getIntOrFloatBitWidth ();
136+ int cWidth = resultTypes[0 ].getIntOrFloatBitWidth ();
137+
138+ if (aWidth < cWidth)
139+ a = rewriter.create <arith::ExtSIOp>(loc, resultTypes[0 ], a);
140+ if (bWidth < cWidth)
141+ b = rewriter.create <arith::ExtSIOp>(loc, resultTypes[0 ], b);
142+
143+ return rewriter.create <arith::MulIOp>(loc, resultTypes, a, b);
144+ }
130145 }
131146
132147 // tosa::NegateOp
@@ -931,7 +946,13 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
931946 auto loc = operation->getLoc ();
932947 auto rank =
933948 cast<RankedTensorType>(operation->getResultTypes ().front ()).getRank ();
934- auto expandedOperands = expandInputRanks (rewriter, loc, operands, rank);
949+ // For the mul op we need to avoid expanding the rank of the optional shift
950+ // input.
951+ auto operandsToExpand =
952+ isa<tosa::MulOp>(operation) ? operands.take_front (2 ) : operands;
953+
954+ auto expandedOperands =
955+ expandInputRanks (rewriter, loc, operandsToExpand, rank);
935956 auto [targetShape, masterOperands] =
936957 computeTargetShape (rewriter, loc, indexPool, expandedOperands);
937958 auto broadcastOperands = broadcastDynamicDimensions (
0 commit comments