
Commit 7e75e5f

ssjia committed
Update on "[ET-VK] Implement linear_dq8ta_q4gsw"

Title says it all! Builds upon the support for quantized linear introduced in the previous diffs to enable dynamically quantized linear. Also included in this diff is a cleanup of the glslh files used across the quantized linear implementations.

Differential Revision: [D81931060](https://our.internmc.facebook.com/intern/diff/D81931060/)

[ghstack-poisoned]
1 parent 06ffc8c · commit 7e75e5f
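To unpack the op name: linear_dq8ta_q4gsw appears to denote a linear layer with dynamically quantized 8-bit activations (dq8ta) multiplying 4-bit group-symmetric-quantized weights (q4gsw). Below is a rough CPU reference of the arithmetic involved, as a sketch only: it assumes row-major tensors, weights already unpacked to int8 values in [-8, 7], and one dynamic scale per activation row. The function name and signature are illustrative, not the actual Vulkan kernel.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for a dynamically-quantized-activation,
// 4-bit group-symmetric-weight linear layer. Layout, scale granularity,
// and group handling are assumptions, not the kernel from this diff.
std::vector<float> linear_dq8ta_q4gsw_ref(
    const std::vector<float>& x, // [M, K] fp32 activations
    const std::vector<int8_t>& w_q4, // [N, K], 4-bit values in [-8, 7]
    const std::vector<float>& w_scales, // [N, K / group_size]
    int M, int K, int N, int group_size) {
  const int num_groups = K / group_size;
  std::vector<float> out(M * N, 0.0f);
  std::vector<int8_t> x_q(K);
  for (int m = 0; m < M; ++m) {
    // Dynamic symmetric quantization of one activation row to int8
    float amax = 0.0f;
    for (int k = 0; k < K; ++k) {
      amax = std::max(amax, std::fabs(x[m * K + k]));
    }
    const float x_scale = std::max(amax, 1e-8f) / 127.0f;
    for (int k = 0; k < K; ++k) {
      x_q[k] = static_cast<int8_t>(std::lround(x[m * K + k] / x_scale));
    }
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int g = 0; g < num_groups; ++g) {
        // Integer dot product within one weight group; on the GPU this
        // inner loop is what maps onto the int8 dot product hardware path
        int32_t idot = 0;
        for (int k = g * group_size; k < (g + 1) * group_size; ++k) {
          idot += int32_t(x_q[k]) * int32_t(w_q4[n * K + k]);
        }
        acc += float(idot) * x_scale * w_scales[n * num_groups + g];
      }
      out[m * N + n] = acc;
    }
  }
  return out;
}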

File tree

3 files changed: +14 -8 lines changed

backends/vulkan/targets.bzl

Lines changed: 2 additions & 1 deletion

@@ -330,7 +330,8 @@ def define_common_targets(is_fbcode = False):
             "//executorch/exir:tensor",
             "//executorch/exir/backend/canonical_partitioners:config_partitioner_lib",
             "//executorch/backends/vulkan/serialization:lib",
-        ]
+        ],
+        typing = True,
     )
 
     runtime.python_library(

backends/vulkan/test/custom_ops/q4gsw_linear.cpp

Lines changed: 8 additions & 3 deletions

@@ -267,9 +267,14 @@ std::vector<TestCase> generate_quantized_linear_test_cases() {
     config.test_case_name = generated_test_case_name;
 
     for (const auto& storage_type : storage_types) {
-      // Test both activation+weight quantized and weight only quantized
-      test_cases.push_back(
-          create_test_case_from_config(config, storage_type, vkapi::kFloat));
+      // Test both activation+weight quantized and weight only quantized, but
+      // only if the current device supports int8 dot product
+      if (vkcompute::api::context()
+              ->adapter_ptr()
+              ->supports_int8_dot_product()) {
+        test_cases.push_back(
+            create_test_case_from_config(config, storage_type, vkapi::kFloat));
+      }
 
       LinearConfig wo_quant_config = config;
       wo_quant_config.op_name = "linear_q4gsw";
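The new guard matters because the activation-quantized path depends on the device's packed int8 dot product capability (exposed in Vulkan via extensions such as VK_KHR_shader_integer_dot_product); without it, only the weight-only-quantized test cases are generated. As a scalar illustration of the gated operation (not code from this diff), a packed 4x int8 dot product with int32 accumulation looks like:

#include <cstdint>

// Scalar equivalent of a packed 4x int8 dot product with int32
// accumulation. Illustration only; on supported devices this is a
// single hardware instruction inside the shader.
int32_t dot_4x8_acc(uint32_t a_packed, uint32_t b_packed, int32_t acc) {
  for (int i = 0; i < 4; ++i) {
    const int8_t a = static_cast<int8_t>((a_packed >> (8 * i)) & 0xFF);
    const int8_t b = static_cast<int8_t>((b_packed >> (8 * i)) & 0xFF);
    acc += int32_t(a) * int32_t(b);
  }
  return acc;
}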

backends/vulkan/utils.py

Lines changed: 4 additions & 4 deletions

@@ -1142,17 +1142,17 @@ def maybe_skip_q_dq_arg_chain(
 
     # If the arg is a view copy node, check if the original node is a dequant node
     if is_dequant_node(arg) or (
-        is_view_copy_node(arg) and is_dequant_node(arg.args[0])
+        is_view_copy_node(arg) and is_dequant_node(arg.args[0])  # pyre-ignore[6]
     ):
+        dequant_node = arg
         if is_view_copy_node(arg):
             dequant_node = arg.args[0]
-        else:
-            dequant_node = arg
 
-        quant_node = dequant_node.args[0]
+        quant_node = dequant_node.args[0]  # pyre-ignore[16]
         assert isinstance(quant_node, torch.fx.Node)
        source_arg = quant_node.args[0]
         assert isinstance(source_arg, torch.fx.Node)
+        assert isinstance(dequant_node, torch.fx.Node)
         return source_arg, quant_node, dequant_node
     else:
         return arg, None, None
