@@ -25,14 +25,14 @@ func.func @test_quantization_per_channel(%arg0: !torch.vtensor<[4,3,7,7],f32>) -
25
25
%0 = torch.vtensor.literal (dense <[4.000000e-01 , 1.000000e-01 , 2.000000e-01 , 3.000000e-01 ]> : tensor <4 xf32 >) : !torch.vtensor <[4 ],f32 >
26
26
%1 = torch.vtensor.literal (dense <[4 , 1 , 2 , 3 ]> : tensor <4 xsi8 >) : !torch.vtensor <[4 ],si8 >
27
27
%int12 = torch.constant.int 12
28
- %int1 = torch.constant.int 1
28
+ %zero = torch.constant.int 0
29
29
// CHECK: %[[QUANT:.+]] = stablehlo.uniform_quantize %[[ARG0]]
30
- // CHECK-SAME: (tensor<4x3x7x7xf32>) -> tensor<4x3x7x7x!quant.uniform<i8:f32:1 , {0.4{{.*}}:4,0.1{{.*}}:1,0.2{{.*}}:2,0.3{{.*}}:3}>>
31
- %2 = torch.aten.quantize_per_channel %arg0 , %0 , %1 , %int1 , %int12 : !torch.vtensor <[4 ,3 ,7 ,7 ],f32 >, !torch.vtensor <[4 ],f32 >, !torch.vtensor <[4 ],si8 >, !torch.int , !torch.int -> !torch.vtensor <[4 ,3 ,7 ,7 ],!torch.qint8 >
30
+ // CHECK-SAME: (tensor<4x3x7x7xf32>) -> tensor<4x3x7x7x!quant.uniform<i8:f32:0 , {0.4{{.*}}:4,0.1{{.*}}:1,0.2{{.*}}:2,0.3{{.*}}:3}>>
31
+ %2 = torch.aten.quantize_per_channel %arg0 , %0 , %1 , %zero , %int12 : !torch.vtensor <[4 ,3 ,7 ,7 ],f32 >, !torch.vtensor <[4 ],f32 >, !torch.vtensor <[4 ],si8 >, !torch.int , !torch.int -> !torch.vtensor <[4 ,3 ,7 ,7 ],!torch.qint8 >
32
32
%3 = torch.aten.int_repr %2 : !torch.vtensor <[4 ,3 ,7 ,7 ],!torch.qint8 > -> !torch.vtensor <[4 ,3 ,7 ,7 ],si8 >
33
33
// CHECK: %[[DEQ:.+]] = stablehlo.uniform_dequantize %[[QUANT]]
34
- // CHECK-SAME: (tensor<4x3x7x7x!quant.uniform<i8:f32:1 , {0.4{{.*}}:4,0.1{{.*}}:1,0.2{{.*}}:2,0.3{{.*}}:3}>>) -> tensor<4x3x7x7xf32>
35
- %4 = torch.aten._make_per_channel_quantized_tensor %3 , %0 , %1 , %int1 : !torch.vtensor <[4 ,3 ,7 ,7 ],si8 >, !torch.vtensor <[4 ],f32 >, !torch.vtensor <[4 ],si8 >, !torch.int -> !torch.vtensor <[4 ,3 ,7 ,7 ],!torch.qint8 >
34
+ // CHECK-SAME: (tensor<4x3x7x7x!quant.uniform<i8:f32:0 , {0.4{{.*}}:4,0.1{{.*}}:1,0.2{{.*}}:2,0.3{{.*}}:3}>>) -> tensor<4x3x7x7xf32>
35
+ %4 = torch.aten._make_per_channel_quantized_tensor %3 , %0 , %1 , %zero : !torch.vtensor <[4 ,3 ,7 ,7 ],si8 >, !torch.vtensor <[4 ],f32 >, !torch.vtensor <[4 ],si8 >, !torch.int -> !torch.vtensor <[4 ,3 ,7 ,7 ],!torch.qint8 >
36
36
%5 = torch.aten.dequantize.self %4 : !torch.vtensor <[4 ,3 ,7 ,7 ],!torch.qint8 > -> !torch.vtensor <[4 ,3 ,7 ,7 ],f32 >
37
37
return %5 : !torch.vtensor <[4 ,3 ,7 ,7 ],f32 >
38
38
}
0 commit comments