
Commit cc302d9

Author: ssjia (committed)

Update base for Update on "[ET-VK][ez] Enable max_pool2d.default"

max_pool2d_with_indices is already implemented; this diff enables max_pool2d as well by just re-using the same implementation.

Differential Revision: [D81513446](https://our.internmc.facebook.com/intern/diff/D81513446/)

[ghstack-poisoned]

1 parent bd98a93 · commit cc302d9
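
For context on the commit title (a sketch, not part of this change): max_pool2d_with_indices computes the same pooled values as max_pool2d and additionally returns the argmax indices, which is why the existing implementation can simply be reused.

import torch

# Sanity check: max_pool2d with return_indices=True yields the same pooled
# values as plain max_pool2d, plus the argmax indices.
x = torch.randn(1, 3, 8, 8)
out = torch.nn.functional.max_pool2d(x, kernel_size=2)
out_wi, idx = torch.nn.functional.max_pool2d(x, kernel_size=2, return_indices=True)
assert torch.equal(out, out_wi)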

File tree

7 files changed: +132 additions, -65 deletions


backends/vulkan/patterns/quantized_linear.py

Lines changed: 75 additions & 31 deletions
@@ -130,6 +130,40 @@ def __init__(self, mm_node: torch.fx.Node) -> None:

         self.match_found = True

+    def is_weight_only_quantized(self) -> bool:
+        return self.quantize_input_node is None
+
+    def is_weight_pergroup_quantized(self) -> bool:
+        weight_shape = self.weight_node.meta["val"].shape
+        scales_shape = self.weight_scales_node.meta["val"].shape
+        if len(scales_shape) != 2:
+            return False
+
+        # Check that:
+        # height dim of scales is same as height dim of weight (N / output channels dim)
+        # width dim of weight (K / in channels dim) is divisible by width dim of scales
+        # (number of quantization groups)
+        return scales_shape[-2] == weight_shape[-2] and (
+            weight_shape[-1] % scales_shape[-1] == 0
+        )
+
+    def is_weight_perchannel_quantized(self) -> bool:
+        weight_shape = self.weight_node.meta["val"].shape
+        scales_shape = self.weight_scales_node.meta["val"].shape
+        if len(scales_shape) != 1:
+            return False
+
+        # scales should have same size as weight's output channels dim
+        return scales_shape[0] == weight_shape[-2]
+
+    def is_input_static_per_tensor_quantized(self) -> bool:
+        if self.quantize_input_node is None:
+            return False
+
+        # For static per tensor quantization, the scales and zeros
+        # are scalars.
+        return isinstance(self.input_scales_node, float)
+

 linear_anchor_nodes = {
     exir_ops.edge.aten.linear.default,
@@ -227,18 +261,10 @@ def make_linear_q4ga_op(
     ep: ExportedProgram,
     graph_module: torch.fx.GraphModule,
     match: QuantizedLinearMatch,
+    weight_tensor: torch.Tensor,
+    weight_scales_tensor: torch.Tensor,
+    weight_zeros_tensor: torch.Tensor,
 ):
-    weight_tensor = get_param_tensor(ep, match.weight_node)
-    assert weight_tensor is not None
-
-    assert match.weight_scales_node is not None
-    weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node)
-    assert weight_scales_tensor is not None
-
-    assert match.weight_zeros_node is not None
-    weight_zeros_tensor = get_param_tensor(ep, match.weight_zeros_node)
-    assert weight_zeros_tensor is not None
-
     packed_quantized_weight_tensor = pack_4bit_weight_tensor(weight_tensor)
     utils.update_program_state_dict(
         ep, match.weight_node.name, packed_quantized_weight_tensor
@@ -281,23 +307,8 @@ def make_linear_q8ta_q8csw_custom_op(
     ep: ExportedProgram,
     graph_module: torch.fx.GraphModule,
     match: QuantizedLinearMatch,
+    weight_tensor: torch.Tensor,
 ):
-    weight_tensor = get_param_tensor(ep, match.weight_node)
-    assert weight_tensor is not None
-
-    assert match.weight_scales_node is not None
-    weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node)
-    assert weight_scales_tensor is not None
-
-    assert match.weight_zeros_node is not None
-    weight_zeros_tensor = get_param_tensor(ep, match.weight_zeros_node)
-    assert weight_zeros_tensor is not None
-
-    bias_tensor = None
-    if match.bias_node is not None:
-        bias_tensor = get_param_tensor(ep, match.bias_node)
-        assert bias_tensor is not None
-
     first_graph_node = list(graph_module.graph.nodes)[0]
     with graph_module.graph.inserting_before(first_graph_node):
         weight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
@@ -340,7 +351,40 @@ def replace_quantized_linear_patterns(
     graph_module: torch.fx.GraphModule,
     match: QuantizedLinearMatch,
 ):
-    if match.quantize_input_node is None:
-        make_linear_q4ga_op(ep, graph_module, match)
-    else:
-        make_linear_q8ta_q8csw_custom_op(ep, graph_module, match)
+    # Extract relevant tensors
+    weight_tensor = get_param_tensor(ep, match.weight_node)
+    assert weight_tensor is not None
+
+    assert match.weight_scales_node is not None
+    weight_scales_tensor = get_param_tensor(ep, match.weight_scales_node)
+    assert weight_scales_tensor is not None
+
+    assert match.weight_zeros_node is not None
+    weight_zeros_tensor = get_param_tensor(ep, match.weight_zeros_node)
+    assert weight_zeros_tensor is not None
+
+    # Biases not supported at the moment
+    if match.bias_node is not None:
+        return
+
+    # Route to appropriate custom op
+    if (
+        match.is_weight_only_quantized()
+        and match.is_weight_pergroup_quantized()
+        and utils.is_in_4bit_range(weight_tensor)
+    ):
+        make_linear_q4ga_op(
+            ep,
+            graph_module,
+            match,
+            weight_tensor,
+            weight_scales_tensor,
+            weight_zeros_tensor,
+        )
+    elif (
+        match.is_input_static_per_tensor_quantized()
+        and match.is_weight_perchannel_quantized()
+    ):
+        make_linear_q8ta_q8csw_custom_op(ep, graph_module, match, weight_tensor)
+
+    # No-op for unsupported quant patterns
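
For reference, a minimal sketch (not part of this diff) of the weight/scales shapes that the routing above distinguishes, using hypothetical sizes N=32, K=64 and a group size of 16:

import torch

# Hypothetical shapes: weight is [N, K]; per-group scales are
# [N, K // group_size] (2D), per-channel scales are [N] (1D).
N, K, group_size = 32, 64, 16
weight = torch.randint(-8, 8, (N, K), dtype=torch.int8)
pergroup_scales = torch.rand(N, K // group_size)
perchannel_scales = torch.rand(N)

# Mirrors is_weight_pergroup_quantized: scales height matches output channels,
# and K is divisible by the number of quantization groups.
assert pergroup_scales.shape[-2] == weight.shape[-2]
assert weight.shape[-1] % pergroup_scales.shape[-1] == 0

# Mirrors is_weight_perchannel_quantized: one scale per output channel.
assert perchannel_scales.shape[0] == weight.shape[-2]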

backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh

Lines changed: 13 additions & 14 deletions
@@ -42,21 +42,20 @@ void fp_accumulate_with_int8_weight(
   // Weight tile is indexed as w_tile.data[k4][n4][n4i]
   // -> gives packed integer containing the 4x 8-bit quantized values at index
   // (n, k), (n, k + 1), (n, k + 2), (n, k + 3)
+  VEC4_T weight_texel;
 #if TILE_K4 == 1 && TILE_N4 == 1
-  [[unroll]] for (int m = 0; m < TILE_M; ++m) {
-    VEC4_T unpacked_weight_k_row;
-    // n = 0
-    unpacked_weight_k_row = unpack_packed_4xint8(w_tile.data[0][0][0]);
-    accum.data[m][0][0] += dot(in_tile.data[m][0], unpacked_weight_k_row);
-    // n = 1
-    unpacked_weight_k_row = unpack_packed_4xint8(w_tile.data[0][0][1]);
-    accum.data[m][0][1] += dot(in_tile.data[m][0], unpacked_weight_k_row);
-    // n = 2
-    unpacked_weight_k_row = unpack_packed_4xint8(w_tile.data[0][0][2]);
-    accum.data[m][0][2] += dot(in_tile.data[m][0], unpacked_weight_k_row);
-    // n = 3
-    unpacked_weight_k_row = unpack_packed_4xint8(w_tile.data[0][0][3]);
-    accum.data[m][0][3] += dot(in_tile.data[m][0], unpacked_weight_k_row);
+  [[unroll]] for (int k = 0; k < 4; ++k) {
+    // Unpack one column of weights
+    weight_texel = VEC4_T(
+        extract_8bit_from_packed_int_le(w_tile.data[0][0][0], k),
+        extract_8bit_from_packed_int_le(w_tile.data[0][0][1], k),
+        extract_8bit_from_packed_int_le(w_tile.data[0][0][2], k),
+        extract_8bit_from_packed_int_le(w_tile.data[0][0][3], k));
+
+    for (int m = 0; m < TILE_M; ++m) {
+      accum.data[m][0] =
+          fma(VEC4_T(in_tile.data[m][0][k]), weight_texel, accum.data[m][0]);
+    }
   }

 #else
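
For intuition, here is a small Python model of the byte extraction the rewritten loop relies on (an assumption about the behavior of extract_8bit_from_packed_int_le, not code from this repo): byte k of a packed 32-bit integer, taken little-endian and reinterpreted as a signed 8-bit value.

def extract_8bit_le(packed: int, k: int) -> int:
    # Take byte k (little-endian) and reinterpret it as signed int8.
    byte = (packed >> (8 * k)) & 0xFF
    return byte - 256 if byte >= 128 else byte

# 0x7F8001FE packs the signed bytes [-2, 1, -128, 127] in little-endian order.
packed = 0x7F8001FE
assert [extract_8bit_le(packed, k) for k in range(4)] == [-2, 1, -128, 127]

With that, the new loop builds one VEC4_T of weights per k and accumulates via fma, so each input scalar scales a whole texel of weights rather than re-unpacking a weight row for every output column.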

backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp

Lines changed: 3 additions & 2 deletions
@@ -128,8 +128,9 @@ std::vector<int64_t> calculate_input_im2col_sizes(

   // K -> flattened convolution window (adjusted)
   const int64_t K = flattened_kernel_len * groups_val;
-  // M -> number of elements in 2D output plane
-  const int64_t M = out_height * out_width * batches;
+  // M -> number of elements in 2D output plane. This is aligned to the next
+  // multiple of 4 since the im2col shader operates on 4x4 blocks.
+  const int64_t M = utils::align_up_4(out_height * out_width * batches);

   return {M, K};
 }
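
A minimal sketch of the M alignment described in the new comment, assuming utils::align_up_4 rounds its argument up to the next multiple of 4:

def align_up_4(n: int) -> int:
    # Round n up to the next multiple of 4 (unchanged if already aligned).
    return (n + 3) // 4 * 4

# e.g. out_height * out_width * batches = 5 * 5 * 1 = 25 -> M = 28,
# so the im2col shader's 4x4 blocks evenly cover the output rows.
assert align_up_4(25) == 28
assert align_up_4(16) == 16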

backends/vulkan/runtime/graph/ops/impl/utils/QuantizationConfig.h

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@ enum class QuantizationGranularity {
   PerChannel,
   PerTensor,
   PerGroup,
-  None,
+  NoQuantization,
 };

 static constexpr QuantizationGranularity kPerChannel =
@@ -26,7 +26,7 @@ static constexpr QuantizationGranularity kPerTensor =
 static constexpr QuantizationGranularity kPerGroup =
     QuantizationGranularity::PerGroup;
 static constexpr QuantizationGranularity kNoQuantization =
-    QuantizationGranularity::None;
+    QuantizationGranularity::NoQuantization;

 struct QuantizationConfig {
   int nbits;

backends/vulkan/test/custom_ops/q8csw_conv2d.cpp

Lines changed: 9 additions & 9 deletions
@@ -395,19 +395,19 @@ std::vector<TestCase> generate_quantized_conv2d_test_cases() {
          std::to_string(config.kernel.w);

       config.test_case_name = prefix + suffix;
-      test_cases.push_back(
-          create_test_case_from_config(config, storage_type, vkapi::kFloat));
+      // The default operator tested is activation + weight quantized conv2d;
+      // however, only test this if the int8 dot product extension is supported
+      if (vkcompute::api::context()
+              ->adapter_ptr()
+              ->supports_int8_dot_product()) {
+        test_cases.push_back(
+            create_test_case_from_config(config, storage_type, vkapi::kFloat));
+      }

       Conv2dConfig wo_quant_config = config;
       wo_quant_config.op_name = "conv2d_q8csw";
       test_cases.push_back(create_test_case_from_config(
           wo_quant_config, storage_type, vkapi::kFloat));
-      // Conv2dConfig config2 = config;
-      // config2.shader_variant_name = "conv2d_q8csw_linear_tiled";
-      // config2.name_suffix = prefix + suffix;
-      // test_cases.push_back(
-      //     create_test_case_from_config(config2, storage_type,
-      //     vkapi::kFloat));
     }
   }

@@ -778,7 +778,7 @@ int main(int argc, char* argv[]) {
       quantized_conv2d_flop_calculator,
       "QuantizedConv2d",
       0,
-      1,
+      10,
       ref_fn);

   return 0;

backends/vulkan/test/custom_ops/q8csw_linear.cpp

Lines changed: 10 additions & 7 deletions
@@ -151,9 +151,9 @@ std::vector<TestCase> generate_quantized_linear_easy_cases() {
   std::vector<TestCase> test_cases;

   // Single simple configuration for debugging
-  int M = 16;
-  int K = 64;
-  int N = 32;
+  int M = 4;
+  int K = 4;
+  int N = 4;

   LinearConfig config = {
       M, // Batch size
@@ -217,9 +217,13 @@ std::vector<TestCase> generate_quantized_linear_test_cases() {
     config.test_case_name = generated_test_case_name;

     for (const auto& storage_type : storage_types) {
-      // Test both activation+weight quantized and weight only quantized
-      test_cases.push_back(
-          create_test_case_from_config(config, storage_type, vkapi::kFloat));
+      if (vkcompute::api::context()
+              ->adapter_ptr()
+              ->supports_int8_dot_product()) {
+        // Test both activation+weight quantized and weight only quantized
+        test_cases.push_back(
+            create_test_case_from_config(config, storage_type, vkapi::kFloat));
+      }

       LinearConfig wo_quant_config = config;
       wo_quant_config.op_name = "linear_q8csw";
@@ -462,7 +466,6 @@ int main(int argc, char* argv[]) {

   ReferenceComputeFunc ref_fn = reference_impl;

-  // Execute easy test cases using the new framework with custom FLOP calculator
   auto results = execute_test_cases(
       generate_quantized_linear_test_cases,
       quantized_linear_flop_calculator,

backends/vulkan/utils.py

Lines changed: 20 additions & 0 deletions
@@ -1172,6 +1172,26 @@ def trace_args_until_placeholder(
     return cur_node, traversed


+def is_in_4bit_range(tensor: torch.Tensor) -> bool:
+    """
+    Check if the given tensor is in the range of 4-bit quantization and is of integer type.
+    """
+    if tensor.dtype not in (torch.int8, torch.uint8):
+        return False
+
+    return tensor.min().item() >= -8 and tensor.max().item() <= 7
+
+
+def is_in_8bit_range(tensor: torch.Tensor) -> bool:
+    """
+    Check if the given tensor is in the range of 8-bit quantization and is of integer type.
+    """
+    if tensor.dtype not in (torch.int8, torch.uint8):
+        return False
+
+    return tensor.min().item() >= -128 and tensor.max().item() <= 127
+
+
 ##
 ## Misc
 ##
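
An illustrative check (not from this diff) of what the new range helpers accept: an int8 tensor with values in [-8, 7] satisfies the 4-bit condition, while a full-range int8 tensor only satisfies the 8-bit one.

import torch

w_4bit = torch.tensor([-8, -1, 0, 7], dtype=torch.int8)
w_8bit = torch.tensor([-128, 0, 127], dtype=torch.int8)

# Mirrors the added helpers: dtype must be int8/uint8 and values must fall in range.
assert w_4bit.min().item() >= -8 and w_4bit.max().item() <= 7        # 4-bit range
assert not (w_8bit.min().item() >= -8 and w_8bit.max().item() <= 7)  # outside 4-bit
assert w_8bit.min().item() >= -128 and w_8bit.max().item() <= 127    # 8-bit range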
