@@ -2250,66 +2250,91 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
2250
2250
Value zeropoint = operands[2 ];
2251
2251
2252
2252
auto operandTy = cast<Torch::ValueTensorType>(operand.getType ());
2253
-
2254
- auto operandETy = operandTy.getDtype ();
2255
2253
auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType ());
2256
2254
if (!scaleTy || !scaleTy.hasSizes ())
2257
2255
return rewriter.notifyMatchFailure (binder.op , " requires known rank" );
2258
2256
if (!resultType.hasDtype ())
2259
2257
return rewriter.notifyMatchFailure (binder.op ,
2260
2258
" requires known result dtype" );
2261
2259
2262
- bool rank0 = scaleTy.getSizes ().size () == 0 ;
2263
- bool length1 =
2264
- scaleTy.getSizes ().size () == 1 && scaleTy.getSizes ()[0 ] == 1 ;
2265
-
2266
- if (!rank0 && !length1)
2267
- return rewriter.notifyMatchFailure (binder.op ,
2268
- " unimplemented: non-scalar scale" );
2260
+ int64_t scaleRank = scaleTy.getSizes ().size ();
2261
+ if (scaleRank > 1 )
2262
+ return rewriter.notifyMatchFailure (
2263
+ binder.op , " unimplemented: only per-tensor or per-axis "
2264
+ " quantization supported" );
2269
2265
auto qTensorTy = getQTorchTypeFromTorchIntType (operandTy);
2270
2266
if (!qTensorTy) {
2271
2267
return rewriter.notifyMatchFailure (binder.op ,
2272
2268
" unsupported result dtype" );
2273
2269
}
2274
2270
2275
- scale = rewriter.create <Torch::AtenItemOp>(
2276
- loc, rewriter.getType <Torch::FloatType>(), scale);
2277
-
2271
+ auto operandETy = operandTy.getDtype ();
2278
2272
bool fpOperand = isa<mlir::FloatType>(operandETy);
2279
- Type zeropointTy = rewriter.getType <Torch::IntType>();
2280
- if (fpOperand)
2281
- zeropointTy = rewriter.getType <Torch::FloatType>();
2282
-
2283
- zeropoint =
2284
- rewriter.create <Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
2285
-
2286
- if (fpOperand) {
2287
- Value none = rewriter.create <Torch::ConstantNoneOp>(loc);
2288
- Value cstFalse = rewriter.create <Torch::ConstantBoolOp>(loc, false );
2289
- auto tyVal = Torch::getScalarTypeForType (resultType.getDtype ());
2290
- Value tyConst = rewriter.create <Torch::ConstantIntOp>(
2291
- loc, rewriter.getType <Torch::IntType>(),
2292
- rewriter.getIntegerAttr (rewriter.getIntegerType (64 ),
2293
- static_cast <int64_t >(tyVal)));
2294
- Value toDtype = rewriter.create <Torch::AtenToDtypeOp>(
2295
- loc, resultType, operand, tyConst,
2296
- /* non_blocking=*/ cstFalse, /* copy=*/ cstFalse,
2297
- /* memory_format=*/ none);
2298
-
2299
- Value one = rewriter.create <Torch::ConstantFloatOp>(
2300
- loc, rewriter.getF64FloatAttr (1.0 ));
2301
- Value sub = rewriter.create <Torch::AtenSubScalarOp>(
2302
- loc, resultType, toDtype, zeropoint, one);
2303
- rewriter.replaceOpWithNewOp <Torch::AtenMulScalarOp>(
2304
- binder.op , resultType, sub, scale);
2273
+ bool isPerTensorQuantization = false ;
2274
+ if (scaleRank == 0 ||
2275
+ llvm::all_of (scaleTy.getSizes (), [](int64_t s) { return s == 1 ; }))
2276
+ isPerTensorQuantization = true ;
2277
+
2278
+ // (TODO) Case: Per-Channel Quantization for floating point input.
2279
+ if (scaleRank == 1 && fpOperand)
2280
+ return rewriter.notifyMatchFailure (
2281
+ binder.op , " unimplemented: support for per-Channel Quantization "
2282
+ " for floating point input not present" );
2283
+
2284
+ if (isPerTensorQuantization) {
2285
+ scale = rewriter.create <Torch::AtenItemOp>(
2286
+ loc, rewriter.getType <Torch::FloatType>(), scale);
2287
+
2288
+ Type zeropointTy = rewriter.getType <Torch::IntType>();
2289
+ if (fpOperand)
2290
+ zeropointTy = rewriter.getType <Torch::FloatType>();
2291
+ zeropoint =
2292
+ rewriter.create <Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
2293
+ }
2294
+
2295
+ if (!fpOperand) {
2296
+ Value quantize;
2297
+ // Case 1: Per-Tensor Quantization for non-floating point input.
2298
+ if (isPerTensorQuantization) {
2299
+ quantize =
2300
+ rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
2301
+ loc, qTensorTy, operand, scale, zeropoint);
2302
+ } else {
2303
+ // Case 2: Per-Channel Quantization for non-floating point input.
2304
+ int64_t axis;
2305
+ if (binder.s64IntegerAttr (axis, " axis" , 1 ))
2306
+ return failure ();
2307
+
2308
+ Value cstAxis = rewriter.create <Torch::ConstantIntOp>(
2309
+ loc, rewriter.getI64IntegerAttr (axis));
2310
+ quantize =
2311
+ rewriter.create <Torch::Aten_MakePerChannelQuantizedTensorOp>(
2312
+ loc, qTensorTy, operand, scale, zeropoint, cstAxis);
2313
+ }
2314
+ rewriter.replaceOpWithNewOp <Torch::AtenDequantizeSelfOp>(
2315
+ binder.op , resultType, quantize);
2305
2316
return success ();
2306
2317
}
2307
2318
2308
- auto quantize =
2309
- rewriter.create <Torch::Aten_MakePerTensorQuantizedTensorOp>(
2310
- loc, qTensorTy, operand, scale, zeropoint);
2311
- rewriter.replaceOpWithNewOp <Torch::AtenDequantizeSelfOp>(
2312
- binder.op , resultType, quantize);
2319
+ // Case 3: Per-Tensor Quantization for floating point input.
2320
+ Value none = rewriter.create <Torch::ConstantNoneOp>(loc);
2321
+ Value cstFalse = rewriter.create <Torch::ConstantBoolOp>(loc, false );
2322
+ auto tyVal = Torch::getScalarTypeForType (resultType.getDtype ());
2323
+ Value tyConst = rewriter.create <Torch::ConstantIntOp>(
2324
+ loc, rewriter.getType <Torch::IntType>(),
2325
+ rewriter.getIntegerAttr (rewriter.getIntegerType (64 ),
2326
+ static_cast <int64_t >(tyVal)));
2327
+ Value toDtype = rewriter.create <Torch::AtenToDtypeOp>(
2328
+ loc, resultType, operand, tyConst,
2329
+ /* non_blocking=*/ cstFalse, /* copy=*/ cstFalse,
2330
+ /* memory_format=*/ none);
2331
+
2332
+ Value one = rewriter.create <Torch::ConstantFloatOp>(
2333
+ loc, rewriter.getF64FloatAttr (1.0 ));
2334
+ Value sub = rewriter.create <Torch::AtenSubScalarOp>(
2335
+ loc, resultType, toDtype, zeropoint, one);
2336
+ rewriter.replaceOpWithNewOp <Torch::AtenMulScalarOp>(
2337
+ binder.op , resultType, sub, scale);
2313
2338
return success ();
2314
2339
});
2315
2340
patterns.onOp (" Div" , 7 ,
0 commit comments