Skip to content

Commit a44f68d

Browse files
authored
[ET-VK] Rename quantize/dequantize functions/shaders to be more generic (#15752)
Title says it all! Currently, quantize/dequantize ops are named like `add_quantize_and_pack_linear_input_node` `add_quantize_and_pack_q8ta_conv2d_input_node` This diff renames them to `add_quantize_and_pack_4h4w_node` `add_quantize_and_pack_4w4c_node` which references the memory layout they produce rather than a specific op. Differential Revision: [D86702456](https://our.internmc.facebook.com/intern/diff/D86702456/)
1 parent eace225 commit a44f68d

13 files changed

+52
-51
lines changed

backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w.glsl

File renamed without changes.

backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w.yaml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,22 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
quantize_and_pack_linear_input:
7+
quantize_and_pack_4h4w:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
OUTPUT_STORAGE: texture3d
1111
INPUT_STORAGE: texture3d
1212
STORAGE: texture3d
1313
GRANULARITY: per_tensor
1414
generate_variant_forall:
15+
combination:
16+
parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE]
17+
combos:
18+
- parameter_values: [texture3d, texture3d]
19+
- parameter_values: [buffer, texture3d]
20+
- parameter_values: [buffer, buffer]
1521
DTYPE:
1622
- VALUE: half
1723
- VALUE: float
1824
shader_variants:
19-
- NAME: quantize_and_pack_linear_input_per_tensor_texture3d_texture3d
20-
- NAME: quantize_and_pack_linear_input_per_tensor_buffer_texture3d
21-
OUTPUT_STORAGE: buffer
22-
- NAME: quantize_and_pack_linear_input_per_tensor_buffer_buffer
23-
OUTPUT_STORAGE: buffer
24-
INPUT_STORAGE: buffer
25+
- NAME: quantize_and_pack_4h4w_per_tensor

backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.glsl

File renamed without changes.

backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4h4w_with_group_sums.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
quantize_and_pack_linear_input_with_sums:
7+
quantize_and_pack_4h4w_with_group_sums:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
OUTPUT_STORAGE: buffer
@@ -16,14 +16,14 @@ quantize_and_pack_linear_input_with_sums:
1616
- VALUE: half
1717
- VALUE: float
1818
shader_variants:
19-
- NAME: quantize_and_pack_linear_input_with_sums_o2w32_buffer_texture3d
20-
- NAME: quantize_and_pack_linear_input_with_sums_o2w32_buffer_buffer
19+
- NAME: quantize_and_pack_4h4w_with_group_sums_o2w32_buffer_texture3d
20+
- NAME: quantize_and_pack_4h4w_with_group_sums_o2w32_buffer_buffer
2121
OUTPUT_STORAGE: buffer
2222
INPUT_STORAGE: buffer
23-
- NAME: quantize_and_pack_linear_input_with_sums_o4w16_buffer_texture3d
23+
- NAME: quantize_and_pack_4h4w_with_group_sums_o4w16_buffer_texture3d
2424
NUM_GROUPS_PER_WG: 4
2525
NUM_WORKERS_PER_GROUP: 16
26-
- NAME: quantize_and_pack_linear_input_with_sums_o4w16_buffer_buffer
26+
- NAME: quantize_and_pack_4h4w_with_group_sums_o4w16_buffer_buffer
2727
NUM_GROUPS_PER_WG: 4
2828
NUM_WORKERS_PER_GROUP: 16
2929
OUTPUT_STORAGE: buffer

backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.glsl

File renamed without changes.

backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_4w4c.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
quantize_and_pack_q8ta_conv2d_input:
7+
quantize_and_pack_4w4c:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
OUTPUT_STORAGE: texture3d
@@ -19,4 +19,4 @@ quantize_and_pack_q8ta_conv2d_input:
1919
DTYPE:
2020
- VALUE: float
2121
shader_variants:
22-
- NAME: quantize_and_pack_q8ta_conv2d_input
22+
- NAME: quantize_and_pack_4w4c_per_tensor

backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.glsl

File renamed without changes.

backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/unpack_4w4c_and_dequantize.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
unpack_and_dequantize_q8ta_conv2d_output:
7+
unpack_4w4c_and_dequantize:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
OUTPUT_STORAGE: texture3d
@@ -19,4 +19,4 @@ unpack_and_dequantize_q8ta_conv2d_output:
1919
DTYPE:
2020
- VALUE: float
2121
shader_variants:
22-
- NAME: unpack_and_dequantize_q8ta_conv2d_output
22+
- NAME: unpack_4w4c_and_dequantize_per_tensor

backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ std::tuple<int64_t, int64_t> get_quantized_input_num_blocks(
4141
return std::make_tuple(num_blocks_M, num_blocks_K);
4242
}
4343

44-
utils::uvec3 quant_pack_input_global_wg_size(
44+
utils::uvec3 quantize_and_pack_4h4w_global_wg_size(
4545
ComputeGraph* graph,
4646
const vkapi::ShaderInfo& shader,
4747
const std::vector<ArgGroup>& args,
@@ -57,7 +57,7 @@ utils::uvec3 quant_pack_input_global_wg_size(
5757
1u};
5858
}
5959

60-
vkapi::ShaderInfo pick_quantize_and_pack_input_with_sums_shader(
60+
vkapi::ShaderInfo pick_quantize_and_pack_4h4w_with_group_sums_shader(
6161
ComputeGraph* graph,
6262
const std::vector<ArgGroup>& args,
6363
const std::vector<ValueRef>& resize_args) {
@@ -67,7 +67,7 @@ vkapi::ShaderInfo pick_quantize_and_pack_input_with_sums_shader(
6767

6868
const int64_t group_size_val = graph->extract_scalar<int64_t>(group_size);
6969

70-
std::string shader_name = "quantize_and_pack_linear_input_with_sums";
70+
std::string shader_name = "quantize_and_pack_4h4w_with_group_sums";
7171
if (group_size_val >= 128) {
7272
shader_name += "_o2w32";
7373
} else {
@@ -82,7 +82,7 @@ vkapi::ShaderInfo pick_quantize_and_pack_input_with_sums_shader(
8282
return VK_KERNEL_FROM_STR(shader_name);
8383
}
8484

85-
utils::uvec3 pick_quantize_and_pack_input_with_sums_global_wg_size(
85+
utils::uvec3 pick_quantize_and_pack_4h4w_with_group_sums_global_wg_size(
8686
ComputeGraph* graph,
8787
const vkapi::ShaderInfo& shader,
8888
const std::vector<ArgGroup>& args,
@@ -113,7 +113,7 @@ utils::uvec3 pick_quantize_and_pack_input_with_sums_global_wg_size(
113113
1u};
114114
}
115115

116-
utils::uvec3 pick_quantize_and_pack_input_with_sums_local_wg_size(
116+
utils::uvec3 pick_quantize_and_pack_4h4w_with_group_sums_local_wg_size(
117117
ComputeGraph* graph,
118118
const vkapi::ShaderInfo& shader,
119119
const utils::uvec3& global_workgroup_size,
@@ -144,7 +144,7 @@ utils::uvec3 pick_quantize_and_pack_input_with_sums_local_wg_size(
144144
// Dispatch logic (Linear)
145145
//
146146

147-
void add_quantize_and_pack_linear_input_node(
147+
void add_quantize_and_pack_4h4w_node(
148148
ComputeGraph& graph,
149149
const QuantizationConfig& input_quant_config,
150150
const ValueRef fp_input,
@@ -164,7 +164,7 @@ void add_quantize_and_pack_linear_input_node(
164164
float inv_scale = 1.0f / graph.extract_scalar<float>(input_scale_data);
165165
int32_t zp = graph.extract_scalar<int32_t>(input_zp_data);
166166

167-
std::string shader_name = "quantize_and_pack_linear_input_per_tensor";
167+
std::string shader_name = "quantize_and_pack_4h4w_per_tensor";
168168
add_storage_type_suffix(shader_name, graph.storage_type_of(packed_int_input));
169169
add_storage_type_suffix(shader_name, graph.storage_type_of(fp_input));
170170
add_dtype_suffix(shader_name, graph.dtype_of(fp_input));
@@ -179,7 +179,7 @@ void add_quantize_and_pack_linear_input_node(
179179
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
180180
graph,
181181
VK_KERNEL_FROM_STR(shader_name),
182-
quant_pack_input_global_wg_size,
182+
quantize_and_pack_4h4w_global_wg_size,
183183
default_pick_local_wg_size,
184184
// Inputs and Outputs
185185
{{packed_int_input, vkapi::kWrite}, {fp_input, vkapi::kRead}},
@@ -193,7 +193,7 @@ void add_quantize_and_pack_linear_input_node(
193193
{}));
194194
}
195195

196-
void add_quantize_and_pack_linear_input_with_sums_node(
196+
void add_quantize_and_pack_4h4w_with_group_sums_node(
197197
ComputeGraph& graph,
198198
const QuantizationConfig& input_quant_config,
199199
const ValueRef fp_input,
@@ -216,9 +216,9 @@ void add_quantize_and_pack_linear_input_with_sums_node(
216216

217217
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
218218
graph,
219-
pick_quantize_and_pack_input_with_sums_shader,
220-
pick_quantize_and_pack_input_with_sums_global_wg_size,
221-
pick_quantize_and_pack_input_with_sums_local_wg_size,
219+
pick_quantize_and_pack_4h4w_with_group_sums_shader,
220+
pick_quantize_and_pack_4h4w_with_group_sums_global_wg_size,
221+
pick_quantize_and_pack_4h4w_with_group_sums_local_wg_size,
222222
// Inputs and Outputs
223223
{{{packed_int_input, int_input_sums}, vkapi::kWrite},
224224
{{fp_input, packed_input_scales, packed_input_zps}, vkapi::kRead}},
@@ -236,7 +236,7 @@ void add_quantize_and_pack_linear_input_with_sums_node(
236236
// Dispatch utilities (Conv2d)
237237
//
238238

239-
utils::uvec3 pick_quantize_and_pack_conv2d_input_global_wg_size(
239+
utils::uvec3 pick_quantize_and_pack_4w4c_global_wg_size(
240240
ComputeGraph* graph,
241241
const vkapi::ShaderInfo& shader,
242242
const std::vector<ArgGroup>& args,
@@ -253,7 +253,7 @@ utils::uvec3 pick_quantize_and_pack_conv2d_input_global_wg_size(
253253
return {W4, H, C4};
254254
}
255255

256-
utils::uvec3 pick_unpack_and_dequantize_conv2d_output_global_wg_size(
256+
utils::uvec3 pick_unpack_4w4c_and_dequantize_global_wg_size(
257257
ComputeGraph* graph,
258258
const vkapi::ShaderInfo& shader,
259259
const std::vector<ArgGroup>& args,
@@ -274,7 +274,7 @@ utils::uvec3 pick_unpack_and_dequantize_conv2d_output_global_wg_size(
274274
// Dispatch logic (Conv2d)
275275
//
276276

277-
void add_quantize_and_pack_q8ta_conv2d_input_node(
277+
void add_quantize_and_pack_4w4c_node(
278278
ComputeGraph& graph,
279279
const ValueRef fp_input,
280280
const ValueRef input_scale,
@@ -284,7 +284,7 @@ void add_quantize_and_pack_q8ta_conv2d_input_node(
284284
int32_t zp = graph.extract_scalar<int32_t>(input_zp);
285285

286286
// Get shader for quantized conv2d linear tiled
287-
std::string kernel_name = "quantize_and_pack_q8ta_conv2d_input";
287+
std::string kernel_name = "quantize_and_pack_4w4c_per_tensor";
288288
add_storage_type_suffix(
289289
kernel_name, graph.storage_type_of(packed_int8_input));
290290
add_storage_type_suffix(kernel_name, graph.storage_type_of(fp_input));
@@ -302,7 +302,7 @@ void add_quantize_and_pack_q8ta_conv2d_input_node(
302302
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
303303
graph,
304304
VK_KERNEL_FROM_STR(kernel_name),
305-
pick_quantize_and_pack_conv2d_input_global_wg_size,
305+
pick_quantize_and_pack_4w4c_global_wg_size,
306306
pick_wc_square_wg_size,
307307
// Inputs and Outputs
308308
{{packed_int8_input, vkapi::kWrite}, {fp_input, vkapi::kRead}},
@@ -318,7 +318,7 @@ void add_quantize_and_pack_q8ta_conv2d_input_node(
318318
nullptr));
319319
}
320320

321-
void add_unpack_and_dequantize_q8ta_conv2d_output_node(
321+
void add_unpack_4w4c_and_dequantize_node(
322322
ComputeGraph& graph,
323323
const ValueRef packed_int8_output,
324324
const ValueRef output_scale,
@@ -328,7 +328,7 @@ void add_unpack_and_dequantize_q8ta_conv2d_output_node(
328328
int32_t zp = graph.extract_scalar<int32_t>(output_zp);
329329

330330
// Get shader for quantized conv2d linear tiled
331-
std::string kernel_name = "unpack_and_dequantize_q8ta_conv2d_output";
331+
std::string kernel_name = "unpack_4w4c_and_dequantize_per_tensor";
332332
add_storage_type_suffix(kernel_name, graph.storage_type_of(fp_output));
333333
add_storage_type_suffix(
334334
kernel_name, graph.storage_type_of(packed_int8_output));
@@ -346,7 +346,7 @@ void add_unpack_and_dequantize_q8ta_conv2d_output_node(
346346
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
347347
graph,
348348
VK_KERNEL_FROM_STR(kernel_name),
349-
pick_unpack_and_dequantize_conv2d_output_global_wg_size,
349+
pick_unpack_4w4c_and_dequantize_global_wg_size,
350350
default_pick_local_wg_size,
351351
// Inputs and Outputs
352352
{{fp_output, vkapi::kWrite}, {packed_int8_output, vkapi::kRead}},
@@ -375,7 +375,7 @@ void quantize_q8ta_for_conv2d(
375375
const ValueRef zero_point = args.at(idx++);
376376
const ValueRef packed_int8_input = args.at(idx++);
377377

378-
add_quantize_and_pack_q8ta_conv2d_input_node(
378+
add_quantize_and_pack_4w4c_node(
379379
graph, fp_input, scale, zero_point, packed_int8_input);
380380
}
381381

@@ -388,7 +388,7 @@ void dequantize_q8to_from_conv2d(
388388
const ValueRef zero_point = args.at(idx++);
389389
const ValueRef fp_output = args.at(idx++);
390390

391-
add_unpack_and_dequantize_q8ta_conv2d_output_node(
391+
add_unpack_4w4c_and_dequantize_node(
392392
graph, packed_int8_output, scale, zero_point, fp_output);
393393
}
394394

@@ -408,10 +408,10 @@ void qdq8ta_conv2d_input(
408408
utils::kBuffer,
409409
utils::kPackedInt8_4W4C);
410410

411-
add_quantize_and_pack_q8ta_conv2d_input_node(
411+
add_quantize_and_pack_4w4c_node(
412412
graph, fp_input, scale, zero_point, packed_int8_input);
413413

414-
add_unpack_and_dequantize_q8ta_conv2d_output_node(
414+
add_unpack_4w4c_and_dequantize_node(
415415
graph, packed_int8_input, scale, zero_point, fp_output);
416416
}
417417

backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ bool is_gemv(ComputeGraph* graph, const ValueRef& fp_input);
2323
// Quantize, Dequantize for Linear/Matmul
2424
//
2525

26-
void add_quantize_and_pack_linear_input_node(
26+
void add_quantize_and_pack_4h4w_node(
2727
ComputeGraph& graph,
2828
const QuantizationConfig& input_quant_config,
2929
const ValueRef fp_input,
@@ -34,7 +34,7 @@ void add_quantize_and_pack_linear_input_node(
3434
const ValueRef packed_int_input,
3535
const ValueRef group_size);
3636

37-
void add_quantize_and_pack_linear_input_with_sums_node(
37+
void add_quantize_and_pack_4h4w_with_group_sums_node(
3838
ComputeGraph& graph,
3939
const QuantizationConfig& input_quant_config,
4040
const ValueRef fp_input,
@@ -48,14 +48,14 @@ void add_quantize_and_pack_linear_input_with_sums_node(
4848
// Quantize, Dequantize for Convolution
4949
//
5050

51-
void add_quantize_and_pack_q8ta_conv2d_input_node(
51+
void add_quantize_and_pack_4w4c_node(
5252
ComputeGraph& graph,
5353
const ValueRef fp_input,
5454
const ValueRef input_scale,
5555
const ValueRef input_zp,
5656
const ValueRef packed_int8_input);
5757

58-
void add_unpack_and_dequantize_q8ta_conv2d_output_node(
58+
void add_unpack_4w4c_and_dequantize_node(
5959
ComputeGraph& graph,
6060
const ValueRef packed_int8_output,
6161
const ValueRef output_scale,

0 commit comments

Comments
 (0)