Skip to content

Commit ffe6839

Browse files
authored
Merge branch 'main' into addScript
2 parents 13261da + fcc7f3b commit ffe6839

File tree

19 files changed

+850
-233
lines changed

19 files changed

+850
-233
lines changed

backends/arm/quantizer/arm_quantizer.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
HistogramObserver,
3939
MinMaxObserver,
4040
MovingAverageMinMaxObserver,
41-
MovingAveragePerChannelMinMaxObserver,
4241
ObserverOrFakeQuantizeConstructor,
4342
PerChannelMinMaxObserver,
4443
PlaceholderObserver,
@@ -95,24 +94,26 @@ def get_symmetric_quantization_config(
9594
**extra_args,
9695
),
9796
)
97+
98+
# Setup quantization config for weights
9899
weight_qscheme = (
99100
torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
100101
)
101102
weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
102103
MinMaxObserver
103104
)
105+
# Determine the right observer/fake-quant constructor
104106
if is_qat:
105-
# TODO: qat + per channel?
106-
weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize
107-
elif is_per_channel:
108-
weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver
107+
# Set plain fake-quant with true min/max
108+
weight_observer_or_fake_quant_ctr = FakeQuantize
109+
else:
110+
# PTQ: set min/max observer
111+
weight_observer_or_fake_quant_ctr = (
112+
PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
113+
)
114+
115+
extra_args = {"eps": 2**-12}
109116

110-
extra_args: Dict[str, Any] = {"eps": 2**-12}
111-
if is_qat:
112-
if weight_qscheme == torch.per_tensor_symmetric:
113-
extra_args["observer"] = MovingAverageMinMaxObserver
114-
else:
115-
extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item]
116117
weight_quantization_spec = QuantizationSpec(
117118
dtype=torch.int8,
118119
quant_min=weight_qmin,

backends/cadence/aot/compiler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def quantize_and_export_to_edge(
278278
dump_graphs: bool = False,
279279
constant_methods: Optional[dict[str, object]] = None,
280280
calibration_data: Optional[list[tuple[object, ...]]] = None,
281+
core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None,
281282
) -> EdgeProgramManager:
282283
"""
283284
Trace, quantize and lower a model/inputs pair to edge IR.
@@ -294,6 +295,7 @@ def quantize_and_export_to_edge(
294295
quantized_model,
295296
dump_graphs=dump_graphs,
296297
constant_methods=constant_methods,
298+
core_aten_exceptions=core_aten_exceptions,
297299
)
298300

299301

backends/vulkan/runtime/graph/ops/glsl/dequantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_buffer:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_buffer
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.glsl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,10 @@ void dequantize_per_tensor() {
139139
[[unroll]] for (int i = 0; i < 4; ++i) {
140140
IN_T qvalue = IN_T(intex[i]);
141141
OUT_T value = dequantize_val(qvalue, scale, zero_point);
142-
outtex[i] = value;
142+
$if OUT_DTYPE == "double":
143+
outtex[i] = float(value);
144+
$else:
145+
outtex[i] = value;
143146
}
144147
write_texel(t_out, pos, outtex);
145148
}
@@ -177,7 +180,10 @@ void dequantize_per_token() {
177180
[[unroll]] for (int i = 0; i < 4; ++i) {
178181
IN_T qvalue = IN_T(intex[i]);
179182
OUT_T value = dequantize_val(qvalue, scale_val, zero_point_val);
180-
outtex[i] = value;
183+
$if OUT_DTYPE == "double":
184+
outtex[i] = float(value);
185+
$else:
186+
outtex[i] = value;
181187
}
182188

183189
write_texel(t_out, pos, outtex);

backends/vulkan/runtime/graph/ops/glsl/dequantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ dequantize_texture:
1111
OUT_DTYPE:
1212
- VALUE: half
1313
- VALUE: float
14+
- VALUE: double
1415
shader_variants:
1516
- NAME: dequantize_per_tensor_texture3d
1617
MODE: per_tensor

backends/vulkan/runtime/graph/ops/glsl/quantize_buffer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_buffer:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/glsl/quantize_texture.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ quantize_texture:
77
IN_DTYPE:
88
- VALUE: half
99
- VALUE: float
10+
- VALUE: double
1011
OUT_DTYPE:
1112
- VALUE: uint8
1213
- VALUE: int8

backends/vulkan/runtime/graph/ops/impl/Quantize.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ void quantize_per_tensor_impl(
188188

189189
// Verify input is a floating point type
190190
VK_CHECK_COND(
191+
graph.dtype_of(input) == vkapi::kDouble ||
191192
graph.dtype_of(input) == vkapi::kFloat ||
192193
graph.dtype_of(input) == vkapi::kHalf);
193194

@@ -214,6 +215,7 @@ void quantize_per_token_impl(
214215

215216
// Verify input is a floating point type
216217
VK_CHECK_COND(
218+
graph.dtype_of(input) == vkapi::kDouble ||
217219
graph.dtype_of(input) == vkapi::kFloat ||
218220
graph.dtype_of(input) == vkapi::kHalf);
219221

backends/vulkan/test/op_tests/dequantize_test.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,12 @@ void test_vulkan_dequantize_per_tensor(
366366
vkcompute::utils::kBuffer,
367367
vkcompute::utils::kBuffer);
368368

369+
// For the texture storage test below, expect a float output instead of a
370+
// double, since the shader can only return 32-bit values anyway
371+
if (out_dtype == at::kDouble) {
372+
out_dtype = at::kFloat;
373+
}
374+
369375
// Test with texture storage
370376
test_vulkan_dequantize_per_tensor_impl(
371377
input_sizes,
@@ -400,6 +406,12 @@ void test_vulkan_dequantize_per_token(
400406
vkcompute::utils::kBuffer,
401407
vkcompute::utils::kBuffer);
402408

409+
// For the texture storage test below, expect a float output instead of a
410+
// double, since the shader can only return 32-bit values anyway
411+
if (out_dtype == at::kDouble) {
412+
out_dtype = at::kFloat;
413+
}
414+
403415
// Test with texture storage
404416
test_vulkan_dequantize_per_token_impl(
405417
input_sizes,
@@ -793,6 +805,24 @@ TEST(
793805
at::kHalf); // output dtype
794806
}
795807

808+
TEST(
809+
VulkanDequantizePerTensorTest,
810+
test_vulkan_dequantize_per_tensor_int8_to_double) {
811+
if (!vkcompute::api::context()
812+
->adapter_ptr()
813+
->has_full_int8_buffers_support()) {
814+
GTEST_SKIP();
815+
}
816+
test_vulkan_dequantize_per_tensor(
817+
{2, 3}, // input sizes
818+
0.05, // scale
819+
10, // zero_point
820+
-128, // quant_min
821+
127, // quant_max
822+
at::kChar, // input dtype
823+
at::kDouble); // output dtype
824+
}
825+
796826
void test_reference_dequantize_per_token(
797827
const std::vector<int>& input_sizes,
798828
const std::vector<float>& scales,
@@ -1288,3 +1318,24 @@ TEST(
12881318
at::kInt, // input dtype
12891319
at::kHalf); // output dtype
12901320
}
1321+
1322+
TEST(
1323+
VulkanDequantizePerTokenTest,
1324+
test_vulkan_dequantize_per_token_int8_to_double) {
1325+
if (!vkcompute::api::context()
1326+
->adapter_ptr()
1327+
->has_full_int8_buffers_support()) {
1328+
GTEST_SKIP();
1329+
}
1330+
std::vector<float> scales = {0.05, 0.001};
1331+
std::vector<int> zero_points = {10, -5};
1332+
1333+
test_vulkan_dequantize_per_token(
1334+
{2, 2}, // input sizes (2 tokens)
1335+
scales,
1336+
zero_points,
1337+
-128, // quant_min
1338+
127, // quant_max
1339+
at::kChar, // input dtype
1340+
at::kDouble); // output dtype
1341+
}

backends/vulkan/test/op_tests/quantize_test.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,12 @@ void test_vulkan_quantize_per_tensor(
315315
vkcompute::utils::kBuffer,
316316
vkcompute::utils::kBuffer);
317317

318+
// If in_dtype is double, switch it to float for the texture implementation,
319+
// since the texture implementation does not support 64-bit inputs
320+
if (in_dtype == at::kDouble) {
321+
in_dtype = at::kFloat;
322+
}
323+
318324
// Test with texture storage
319325
test_vulkan_quantize_per_tensor_impl(
320326
input_sizes,
@@ -349,6 +355,12 @@ void test_vulkan_quantize_per_token(
349355
vkcompute::utils::kBuffer,
350356
vkcompute::utils::kBuffer);
351357

358+
// If in_dtype is double, switch it to float for the texture implementation,
359+
// since the texture implementation does not support 64-bit inputs
360+
if (in_dtype == at::kDouble) {
361+
in_dtype = at::kFloat;
362+
}
363+
352364
// Test with texture storage
353365
test_vulkan_quantize_per_token_impl(
354366
input_sizes,
@@ -655,6 +667,24 @@ TEST(
655667
at::kChar); // output dtype
656668
}
657669

670+
TEST(
671+
VulkanQuantizePerTensorTest,
672+
test_vulkan_quantize_per_tensor_double_to_int8) {
673+
if (!vkcompute::api::context()
674+
->adapter_ptr()
675+
->has_full_int8_buffers_support()) {
676+
GTEST_SKIP();
677+
}
678+
test_vulkan_quantize_per_tensor(
679+
{2, 3}, // input sizes
680+
0.01, // scale
681+
1, // zero_point
682+
-128, // quant_min
683+
127, // quant_max
684+
at::kDouble, // input dtype
685+
at::kChar); // output dtype
686+
}
687+
658688
void test_reference_quantize_per_token(
659689
const std::vector<int>& input_sizes,
660690
const std::vector<float>& pre_scales,
@@ -1075,3 +1105,24 @@ TEST(VulkanQuantizePerTensorTest, test_vulkan_quantize_per_token_half_to_int8) {
10751105
at::kHalf, // input dtype
10761106
at::kChar); // output dtype
10771107
}
1108+
1109+
TEST(
1110+
VulkanQuantizePerTensorTest,
1111+
test_vulkan_quantize_per_token_double_to_int8) {
1112+
if (!vkcompute::api::context()
1113+
->adapter_ptr()
1114+
->has_full_int8_buffers_support()) {
1115+
GTEST_SKIP();
1116+
}
1117+
std::vector<float> scales = {0.1, 0.2};
1118+
std::vector<int> zero_points = {0, 5};
1119+
1120+
test_vulkan_quantize_per_token(
1121+
{2, 2}, // input sizes (2 tokens)
1122+
scales,
1123+
zero_points,
1124+
-128, // quant_min
1125+
127, // quant_max
1126+
at::kDouble, // input dtype
1127+
at::kChar); // output dtype
1128+
}

0 commit comments

Comments
 (0)