@@ -186,56 +186,63 @@ static Value createLinalgBodyCalculationForElementwiseOp(
186
186
if (isa<tosa::NegateOp>(op)) {
187
187
auto negate = cast<tosa::NegateOp>(op);
188
188
189
+ int64_t inZp = 0 , outZp = 0 ;
189
190
FailureOr<int64_t > maybeInZp = negate.getInput1ZeroPoint ();
190
- if (failed (maybeInZp)) {
191
- (void )rewriter.notifyMatchFailure (
192
- op, " input1 zero point cannot be statically determined" );
193
- return nullptr ;
194
- }
195
-
196
191
FailureOr<int64_t > maybeOutZp = negate.getOutputZeroPoint ();
197
- if (failed (maybeOutZp)) {
198
- (void )rewriter.notifyMatchFailure (
199
- op, " output zero point cannot be statically determined" );
200
- return nullptr ;
201
- }
202
-
203
- int64_t inZp = *maybeInZp;
204
- int64_t outZp = *maybeOutZp;
192
+ bool hasInZp = !failed (maybeInZp);
193
+ bool hasOutZp = !failed (maybeOutZp);
194
+ if (hasInZp)
195
+ inZp = *maybeInZp;
196
+ if (hasOutZp)
197
+ outZp = *maybeOutZp;
205
198
206
199
if (isa<FloatType>(elementTy))
207
200
return arith::NegFOp::create (rewriter, loc, resultTypes, args[0 ]);
208
201
209
202
if (isa<IntegerType>(elementTy)) {
210
- if (!inZp && !outZp) {
203
+ if (hasInZp && hasOutZp && !inZp && !outZp) {
211
204
auto constant = arith::ConstantOp::create (
212
205
rewriter, loc, IntegerAttr::get (elementTy, 0 ));
213
206
return arith::SubIOp::create (rewriter, loc, resultTypes, constant,
214
207
args[0 ]);
215
208
}
216
209
210
+ Value zpAddValue;
211
+ Type intermediateType;
217
212
// Compute the maximum value that can occur in the intermediate buffer.
218
213
const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth ();
219
- const int64_t zpAdd = inZp + outZp;
220
- const int64_t maxValue =
221
- APInt::getSignedMaxValue (inputBitWidth).getSExtValue () +
222
- std::abs (zpAdd) + 1 ;
223
-
224
- // Convert that maximum value into the maximum bitwidth needed to
225
- // represent it. We assume 48-bit numbers may be supported further in
226
- // the pipeline.
227
214
int intermediateBitWidth = 64 ;
228
- if (maxValue <= APInt::getSignedMaxValue (16 ).getSExtValue ()) {
229
- intermediateBitWidth = 16 ;
230
- } else if (maxValue <= APInt::getSignedMaxValue (32 ).getSExtValue ()) {
231
- intermediateBitWidth = 32 ;
232
- } else if (maxValue <= APInt::getSignedMaxValue (48 ).getSExtValue ()) {
233
- intermediateBitWidth = 48 ;
234
- }
235
215
236
- Type intermediateType = rewriter.getIntegerType (intermediateBitWidth);
237
- Value zpAddValue = arith::ConstantOp::create (
238
- rewriter, loc, rewriter.getIntegerAttr (intermediateType, zpAdd));
216
+ if (hasInZp && hasOutZp) {
217
+ // Compute the maximum value that can occur in the intermediate buffer.
218
+ const int64_t zpAdd = inZp + outZp;
219
+ const int64_t maxValue =
220
+ APInt::getSignedMaxValue (inputBitWidth).getSExtValue () +
221
+ std::abs (zpAdd) + 1 ;
222
+
223
+ // Convert that maximum value into the maximum bitwidth needed to
224
+ // represent it. We assume 48-bit numbers may be supported further in
225
+ // the pipeline.
226
+ if (maxValue <= APInt::getSignedMaxValue (16 ).getSExtValue ()) {
227
+ intermediateBitWidth = 16 ;
228
+ } else if (maxValue <= APInt::getSignedMaxValue (32 ).getSExtValue ()) {
229
+ intermediateBitWidth = 32 ;
230
+ } else if (maxValue <= APInt::getSignedMaxValue (48 ).getSExtValue ()) {
231
+ intermediateBitWidth = 48 ;
232
+ }
233
+
234
+ intermediateType = rewriter.getIntegerType (intermediateBitWidth);
235
+ zpAddValue = rewriter.create <arith::ConstantOp>(
236
+ loc, rewriter.getIntegerAttr (intermediateType, zpAdd));
237
+ } else {
238
+ intermediateType = rewriter.getIntegerType (intermediateBitWidth);
239
+ auto arg1 =
240
+ rewriter.create <arith::ExtSIOp>(loc, intermediateType, args[1 ]);
241
+ auto arg2 =
242
+ rewriter.create <arith::ExtSIOp>(loc, intermediateType, args[2 ]);
243
+ zpAddValue =
244
+ rewriter.create <arith::AddIOp>(loc, intermediateType, arg1, arg2);
245
+ }
239
246
240
247
// The negation can be applied by doing:
241
248
// outputValue = inZp + outZp - inputValue
@@ -1013,9 +1020,14 @@ static ValueRange getBroadcastableOperands(Operation *operation,
1013
1020
else
1014
1021
return operands.take_front (3 );
1015
1022
}
1016
- // Input1_zp and output_zp cannot broadcast
1017
- if (isa<tosa::NegateOp>(operation))
1023
+ if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
1024
+ FailureOr<int64_t > maybeInZp = negate.getInput1ZeroPoint ();
1025
+ FailureOr<int64_t > maybeOutZp = negate.getOutputZeroPoint ();
1026
+ if (failed (maybeOutZp) && failed (maybeInZp))
1027
+ return operands;
1028
+ // Input1_zp and output_zp cannot broadcast when they are constants.
1018
1029
return operands.take_front (1 );
1030
+ }
1019
1031
return operands;
1020
1032
}
1021
1033
0 commit comments