@@ -90,43 +90,59 @@ static Value createLinalgBodyCalculationForElementwiseOp(
9090 }
9191
9292 // tosa::MulOp
93- if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
94- return rewriter.create <arith::MulFOp>(loc, resultTypes, args);
95-
96- if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
97- Value a = args[0 ];
98- Value b = args[1 ];
99- auto shift =
100- cast<IntegerAttr>(op->getAttr (" shift" )).getValue ().getSExtValue ();
101- if (shift > 0 ) {
102- auto shiftConst =
103- rewriter.create <arith::ConstantIntOp>(loc, shift, /* bitwidth=*/ 8 );
104- if (!a.getType ().isInteger (32 ))
105- a = rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), a);
106-
107- if (!b.getType ().isInteger (32 ))
108- b = rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), b);
109-
110- auto result = rewriter.create <tosa::ApplyScaleOp>(
111- loc, rewriter.getI32Type (), a, b, shiftConst,
112- rewriter.getBoolAttr (false ));
113-
114- if (elementTy.isInteger (32 ))
115- return result;
116-
117- return rewriter.create <arith::TruncIOp>(loc, elementTy, result);
93+ if (isa<tosa::MulOp>(op)) {
94+ auto shift_val = cast<tosa::MulOp>(op).getShift ();
95+ if (!elementTy.isInteger (32 ) && shift_val.getImpl ()) {
96+ (void )rewriter.notifyMatchFailure (
97+ op, " Cannot have shift value for non i32 output" );
98+ return nullptr ;
99+ };
100+
101+ if (isa<FloatType>(elementTy)) {
102+ return rewriter.create <arith::MulFOp>(loc, resultTypes, args[0 ], args[1 ]);
118103 }
119104
120- int aWidth = a.getType ().getIntOrFloatBitWidth ();
121- int bWidth = b.getType ().getIntOrFloatBitWidth ();
122- int cWidth = resultTypes[0 ].getIntOrFloatBitWidth ();
105+ if (isa<IntegerType>(elementTy)) {
106+ int32_t shift = 0 ;
107+ ElementsAttr shift_elem;
108+ if (shift_val.getImpl () &&
109+ matchPattern (shift_val, m_Constant (&shift_elem))) {
110+ // Explicit shift is set.
111+ shift = shift_elem.getValues <IntegerAttr>()[0 ].getInt ();
112+ }
113+
114+ Value a = args[0 ];
115+ Value b = args[1 ];
116+ if (shift > 0 ) {
117+ auto shiftConst =
118+ rewriter.create <arith::ConstantIntOp>(loc, shift, /* bitwidth=*/ 8 );
119+ if (!a.getType ().isInteger (32 ))
120+ a = rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), a);
121+
122+ if (!b.getType ().isInteger (32 ))
123+ b = rewriter.create <arith::ExtSIOp>(loc, rewriter.getI32Type (), b);
124+
125+ auto result = rewriter.create <tosa::ApplyScaleOp>(
126+ loc, rewriter.getI32Type (), a, b, shiftConst,
127+ rewriter.getBoolAttr (false ));
123128
124- if (aWidth < cWidth)
125- a = rewriter.create <arith::ExtSIOp>(loc, resultTypes[0 ], a);
126- if (bWidth < cWidth)
127- b = rewriter.create <arith::ExtSIOp>(loc, resultTypes[0 ], b);
129+ if (elementTy.isInteger (32 ))
130+ return result;
128131
129- return rewriter.create <arith::MulIOp>(loc, resultTypes, a, b);
132+ return rewriter.create <arith::TruncIOp>(loc, elementTy, result);
133+ }
134+
135+ int aWidth = a.getType ().getIntOrFloatBitWidth ();
136+ int bWidth = b.getType ().getIntOrFloatBitWidth ();
137+ int cWidth = resultTypes[0 ].getIntOrFloatBitWidth ();
138+
139+ if (aWidth < cWidth)
140+ a = rewriter.create <arith::ExtSIOp>(loc, resultTypes[0 ], a);
141+ if (bWidth < cWidth)
142+ b = rewriter.create <arith::ExtSIOp>(loc, resultTypes[0 ], b);
143+
144+ return rewriter.create <arith::MulIOp>(loc, resultTypes, a, b);
145+ }
130146 }
131147
132148 // tosa::NegateOp
@@ -940,7 +956,13 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
940956 auto loc = operation->getLoc ();
941957 auto rank =
942958 cast<RankedTensorType>(operation->getResultTypes ().front ()).getRank ();
943- auto expandedOperands = expandInputRanks (rewriter, loc, operands, rank);
959+ // For the mul op we need to avoid expanding the rank of the optional shift
960+ // input.
961+ auto operandsToExpand =
962+ isa<tosa::MulOp>(operation) ? operands.take_front (2 ) : operands;
963+
964+ auto expandedOperands =
965+ expandInputRanks (rewriter, loc, operandsToExpand, rank);
944966 auto [targetShape, masterOperands] =
945967 computeTargetShape (rewriter, loc, indexPool, expandedOperands);
946968 auto broadcastOperands = broadcastDynamicDimensions (
0 commit comments