Commit 40b3469

[ONNX] Add per-channel quantization support for DequantizeLinear op (#4090)
Fixes nod-ai/SHARK-ModelDev#914.

Signed-off-by: Vivek Khandelwal <[email protected]>
1 parent 3cd2985 commit 40b3469
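
For context (editorial addition, not part of the commit): ONNX DequantizeLinear computes y = (x - zero_point) * scale. In per-tensor mode, scale and zero_point are scalars; in per-channel (per-axis) mode they are 1-D tensors whose length equals the input's extent along the axis attribute, so each channel carries its own scale and zero point. The C++ sketch below illustrates that math for a rank-2 int8 input quantized along axis 0; dequantizePerChannel is a hypothetical helper name, and the sketch is a reference for the op's semantics, not the lowering implemented here.

// Editorial sketch -- not part of this commit. Illustrates ONNX
// DequantizeLinear semantics: y = (x - zero_point) * scale, where scale and
// zero_point are scalars (per-tensor) or 1-D tensors indexed along `axis`
// (per-channel). Here: rank-2 int8 input, quantized along axis 0.
#include <cstdint>
#include <iostream>
#include <vector>

// y[c][n] = (x[c][n] - zeroPoint[c]) * scale[c]  (hypothetical helper)
std::vector<std::vector<float>>
dequantizePerChannel(const std::vector<std::vector<int8_t>> &x,
                     const std::vector<float> &scale,
                     const std::vector<int8_t> &zeroPoint) {
  std::vector<std::vector<float>> y(x.size());
  for (size_t c = 0; c < x.size(); ++c)
    for (int8_t v : x[c])
      y[c].push_back((static_cast<float>(v) - zeroPoint[c]) * scale[c]);
  return y;
}

int main() {
  // Two channels, each with its own scale and zero point.
  auto y = dequantizePerChannel({{10, 20}, {30, 40}},
                                /*scale=*/{0.5f, 0.25f},
                                /*zeroPoint=*/{0, 8});
  for (const auto &row : y) {
    for (float v : row)
      std::cout << v << ' '; // prints: 5 10  then  5.5 8
    std::cout << '\n';
  }
  return 0;
}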

2 files changed (+80, -43 lines)

lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Lines changed: 68 additions & 43 deletions
@@ -2250,66 +2250,91 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         Value zeropoint = operands[2];

         auto operandTy = cast<Torch::ValueTensorType>(operand.getType());
-
-        auto operandETy = operandTy.getDtype();
         auto scaleTy = dyn_cast<Torch::ValueTensorType>(scale.getType());
         if (!scaleTy || !scaleTy.hasSizes())
           return rewriter.notifyMatchFailure(binder.op, "requires known rank");
         if (!resultType.hasDtype())
           return rewriter.notifyMatchFailure(binder.op,
                                              "requires known result dtype");

-        bool rank0 = scaleTy.getSizes().size() == 0;
-        bool length1 =
-            scaleTy.getSizes().size() == 1 && scaleTy.getSizes()[0] == 1;
-
-        if (!rank0 && !length1)
-          return rewriter.notifyMatchFailure(binder.op,
-                                             "unimplemented: non-scalar scale");
+        int64_t scaleRank = scaleTy.getSizes().size();
+        if (scaleRank > 1)
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: only per-tensor or per-axis "
+                         "quantization supported");
         auto qTensorTy = getQTorchTypeFromTorchIntType(operandTy);
         if (!qTensorTy) {
           return rewriter.notifyMatchFailure(binder.op,
                                              "unsupported result dtype");
         }

-        scale = rewriter.create<Torch::AtenItemOp>(
-            loc, rewriter.getType<Torch::FloatType>(), scale);
-
+        auto operandETy = operandTy.getDtype();
         bool fpOperand = isa<mlir::FloatType>(operandETy);
-        Type zeropointTy = rewriter.getType<Torch::IntType>();
-        if (fpOperand)
-          zeropointTy = rewriter.getType<Torch::FloatType>();
-
-        zeropoint =
-            rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
-
-        if (fpOperand) {
-          Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
-          Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
-          auto tyVal = Torch::getScalarTypeForType(resultType.getDtype());
-          Value tyConst = rewriter.create<Torch::ConstantIntOp>(
-              loc, rewriter.getType<Torch::IntType>(),
-              rewriter.getIntegerAttr(rewriter.getIntegerType(64),
-                                      static_cast<int64_t>(tyVal)));
-          Value toDtype = rewriter.create<Torch::AtenToDtypeOp>(
-              loc, resultType, operand, tyConst,
-              /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
-              /*memory_format=*/none);
-
-          Value one = rewriter.create<Torch::ConstantFloatOp>(
-              loc, rewriter.getF64FloatAttr(1.0));
-          Value sub = rewriter.create<Torch::AtenSubScalarOp>(
-              loc, resultType, toDtype, zeropoint, one);
-          rewriter.replaceOpWithNewOp<Torch::AtenMulScalarOp>(
-              binder.op, resultType, sub, scale);
+        bool isPerTensorQuantization = false;
+        if (scaleRank == 0 ||
+            llvm::all_of(scaleTy.getSizes(), [](int64_t s) { return s == 1; }))
+          isPerTensorQuantization = true;
+
+        // (TODO) Case: Per-Channel Quantization for floating point input.
+        if (scaleRank == 1 && fpOperand)
+          return rewriter.notifyMatchFailure(
+              binder.op, "unimplemented: support for per-Channel Quantization "
+                         "for floating point input not present");
+
+        if (isPerTensorQuantization) {
+          scale = rewriter.create<Torch::AtenItemOp>(
+              loc, rewriter.getType<Torch::FloatType>(), scale);
+
+          Type zeropointTy = rewriter.getType<Torch::IntType>();
+          if (fpOperand)
+            zeropointTy = rewriter.getType<Torch::FloatType>();
+          zeropoint =
+              rewriter.create<Torch::AtenItemOp>(loc, zeropointTy, zeropoint);
+        }
+
+        if (!fpOperand) {
+          Value quantize;
+          // Case 1: Per-Tensor Quantization for non-floating point input.
+          if (isPerTensorQuantization) {
+            quantize =
+                rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+                    loc, qTensorTy, operand, scale, zeropoint);
+          } else {
+            // Case 2: Per-Channel Quantization for non-floating point input.
+            int64_t axis;
+            if (binder.s64IntegerAttr(axis, "axis", 1))
+              return failure();
+
+            Value cstAxis = rewriter.create<Torch::ConstantIntOp>(
+                loc, rewriter.getI64IntegerAttr(axis));
+            quantize =
+                rewriter.create<Torch::Aten_MakePerChannelQuantizedTensorOp>(
+                    loc, qTensorTy, operand, scale, zeropoint, cstAxis);
+          }
+          rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(
+              binder.op, resultType, quantize);
           return success();
         }

-        auto quantize =
-            rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-                loc, qTensorTy, operand, scale, zeropoint);
-        rewriter.replaceOpWithNewOp<Torch::AtenDequantizeSelfOp>(
-            binder.op, resultType, quantize);
+        // Case 3: Per-Tensor Quantization for floating point input.
+        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
+        Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
+        auto tyVal = Torch::getScalarTypeForType(resultType.getDtype());
+        Value tyConst = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getType<Torch::IntType>(),
+            rewriter.getIntegerAttr(rewriter.getIntegerType(64),
+                                    static_cast<int64_t>(tyVal)));
+        Value toDtype = rewriter.create<Torch::AtenToDtypeOp>(
+            loc, resultType, operand, tyConst,
+            /*non_blocking=*/cstFalse, /*copy=*/cstFalse,
+            /*memory_format=*/none);
+
+        Value one = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(1.0));
+        Value sub = rewriter.create<Torch::AtenSubScalarOp>(
+            loc, resultType, toDtype, zeropoint, one);
+        rewriter.replaceOpWithNewOp<Torch::AtenMulScalarOp>(
+            binder.op, resultType, sub, scale);
         return success();
       });
   patterns.onOp("Div", 7,
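
A note on the control flow above: for per-tensor quantization the lowering still extracts scale and zeropoint as scalars with Torch::AtenItemOp, whereas the new per-channel path keeps them as 1-D tensors and passes them, along with the constant axis, to Torch::Aten_MakePerChannelQuantizedTensorOp; both integer cases then share the final Torch::AtenDequantizeSelfOp. Floating-point operands retain the scalar cast/subtract/multiply lowering (Case 3), and per-channel support for them is explicitly left as a TODO.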

test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir

Lines changed: 12 additions & 0 deletions
@@ -817,6 +817,18 @@ func.func @test_dequantizelinear_fp8(%arg0: !torch.vtensor<[6],f8E4M3FN>, %arg1:

 // -----

+// CHECK-LABEL: @test_dequantizelinear_per_channel_si8
+func.func @test_dequantizelinear_per_channel_si8(%arg0: !torch.vtensor<[64,3,3,3],si8>, %arg1: !torch.vtensor<[64],f32>, %arg2: !torch.vtensor<[64],si8>) -> !torch.vtensor<[64,3,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64} {
+  %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[64,3,3,3],si8>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],si8>) -> !torch.vtensor<[64,3,3,3],f32>
+  // CHECK: %[[AXIS:.+]] = torch.constant.int 0
+  // CHECK: %[[MAKE:.+]] = torch.aten._make_per_channel_quantized_tensor %arg0, %arg1, %arg2,
+  // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]]
+  // CHECK: return %[[DEQ]]
+  return %0 : !torch.vtensor<[64,3,3,3],f32>
+}
+
+// -----
+
 // CHECK-LABEL: @test_div_bcast
 func.func @test_div_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[3,4,5],f32>

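A brief usage note (not from the commit): files under test/Conversion are lit/FileCheck tests, so after building torch-mlir the new case can typically be exercised by pointing llvm-lit at test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir or by running the project's check target; the exact invocation depends on your build setup.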