// Private, body-less declarations used as `decomposition` symbols by the
// `stablehlo.composite "quant.fake_quant"` ops in this test. Each signature
// maps a tensor to a tensor of the same shape and element type (f32).
// The impl_17_0, impl_i2_0 and impl_i2_1 symbols are referenced by composites
// visible in this file; impl_0 and impl_5_0 are presumably referenced by
// composites in the part of @serving_default not shown here (TODO confirm).
33func.func private @XlaCallModule_quant.fake_quant.impl_0 (tensor <1 x28 x28 x3 xf32 >) -> tensor <1 x28 x28 x3 xf32 >
44func.func private @XlaCallModule_quant.fake_quant.impl_5_0 (tensor <2 x1 x1 x1 xf32 >) -> tensor <2 x1 x1 x1 xf32 >
55func.func private @XlaCallModule_quant.fake_quant.impl_17_0 (tensor <1 x30 x30 x2 xf32 >) -> tensor <1 x30 x30 x2 xf32 >
6+ func.func private @XlaCallModule_quant.fake_quant.impl_i2_0 (tensor <1 x4 xf32 >) -> tensor <1 x4 xf32 >
7+ func.func private @XlaCallModule_quant.fake_quant.impl_i2_1 (tensor <1 x4 xf32 >) -> tensor <1 x4 xf32 >
68// CHECK-LABEL: func.func @serving_default
79func.func @serving_default (%arg0: tensor <1 x28 x28 x3 xf32 >) -> (tensor <1 x30 x30 x2 xf32 >) {
810 %cst = arith.constant dense <[[0 , 0 ], [1 , 1 ], [1 , 1 ], [0 , 0 ]]> : tensor <4 x2 xi32 >
@@ -22,4 +24,15 @@ func.func @serving_default(%arg0: tensor<1x28x28x3xf32>) -> (tensor<1x30x30x2xf3
2224 // CHECK-OFF: %[[DEQUANT2:.+]] = "tfl.dequantize"(%[[QUANT2]]) : (tensor<1x30x30x2x!quant.uniform<i8:f32, 0.018049469217658043:8>>) -> tensor<1x30x30x2xf32>
2325 %5 = stablehlo.composite " quant.fake_quant" %4 {composite_attributes = {dtype = " i8" , narrow_range = false , scale = dense <0.0180494692 > : tensor <1 xf32 >, zero_point = dense <8 > : tensor <1 xi32 >}, decomposition = @XlaCallModule_quant.fake_quant.impl_17_0 } : (tensor <1 x30 x30 x2 xf32 >) -> tensor <1 x30 x30 x2 xf32 >
2426 return %5 : tensor <1 x30 x30 x2 xf32 >
27+ }
28+
29+ // CHECK-LABEL: func.func @i2_test
// Test case for 2-bit (`dtype = "i2"`) fake-quant composites on a 1x4 f32
// tensor. Two composites are chained: a per-tensor one followed by a
// per-axis one. For each, the FileCheck directives expect a
// "tfl.quantize" op followed by a "tfl.dequantize" op.
30+ func.func @i2_test (%arg0: tensor <1 x4 xf32 >) -> (tensor <1 x4 xf32 >) {
// Case 1: per-tensor. The composite carries a single scale (1.0), a single
// zero_point (0) and narrow_range = false; the expected quantized type is
// written as `i2:f32, 1.000000e+00`.
31+ // CHECK: %[[QUANT0:.+]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x4x!quant.uniform<i2:f32, 1.000000e+00>>}> : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<i2:f32, 1.000000e+00>>
32+ // CHECK: %[[DEQUANT0:.+]] = "tfl.dequantize"(%[[QUANT0]]) : (tensor<1x4x!quant.uniform<i2:f32, 1.000000e+00>>) -> tensor<1x4xf32>
33+ %0 = stablehlo.composite " quant.fake_quant" %arg0 {composite_attributes = {dtype = " i2" , narrow_range = false , scale = dense <1.0 > : tensor <1 xf32 >, zero_point = dense <0 > : tensor <1 xi32 >}, decomposition = @XlaCallModule_quant.fake_quant.impl_i2_0 } : (tensor <1 x4 xf32 >) -> tensor <1 x4 xf32 >
// Case 2: per-axis, fed by the result of case 1. The composite sets
// narrow_range = true and quantization_dimension = 1, supplies four scales
// (1.0, 2.0, 3.0, 4.0) and no zero_point attribute. The expected quantized
// type is written with storage bounds `<-1:1>`, axis `:1` and the same four
// scales.
34+ // CHECK: %[[QUANT1:.+]] = "tfl.quantize"(%[[DEQUANT0]]) <{qtype = tensor<1x4x!quant.uniform<i2<-1:1>:f32:1, {1.000000e+00,2.000000e+00,3.000000e+00,4.000000e+00}>>}> : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform<i2<-1:1>:f32:1, {1.000000e+00,2.000000e+00,3.000000e+00,4.000000e+00}>>
35+ // CHECK: %[[DEQUANT1:.+]] = "tfl.dequantize"(%[[QUANT1]]) : (tensor<1x4x!quant.uniform<i2<-1:1>:f32:1, {1.000000e+00,2.000000e+00,3.000000e+00,4.000000e+00}>>) -> tensor<1x4xf32>
36+ %1 = stablehlo.composite " quant.fake_quant" %0 {composite_attributes = {dtype = " i2" , narrow_range = true , quantization_dimension = 1 : i32 , scale = dense <[1.0 , 2.0 , 3.0 , 4.0 ]> : tensor <4 xf32 >}, decomposition = @XlaCallModule_quant.fake_quant.impl_i2_1 } : (tensor <1 x4 xf32 >) -> tensor <1 x4 xf32 >
// The function returns the output of the second (per-axis) composite.
37+ return %1 : tensor <1 x4 xf32 >
2538}
0 commit comments