
Commit a0b1661

Update on "[ET-VK] Minor performance improvements for buffer to int8 quantized packing."
This diff provides minor performance improvements for buffer to int8 quantized packing in the Vulkan runtime graph ops.

Differential Revision: [D74616519](https://our.internmc.facebook.com/intern/diff/D74616519/)

[ghstack-poisoned]
2 parents c6cad4f + 02c4b0d commit a0b1661

23 files changed: +546 −117 lines


.lintrunner.toml

Lines changed: 16 additions & 16 deletions
@@ -10,7 +10,7 @@ exclude_patterns = [
     'exir/serde/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -19,7 +19,7 @@ command = [
     '@{{PATHSFILE}}'
 ]
 init_command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -41,7 +41,7 @@ exclude_patterns = [
     'exir/serde/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -50,7 +50,7 @@ command = [
     '@{{PATHSFILE}}'
 ]
 init_command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -84,7 +84,7 @@ exclude_patterns = [
     'runtime/core/portable_type/c10/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -95,7 +95,7 @@ command = [
     '@{{PATHSFILE}}'
 ]
 init_command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -117,7 +117,7 @@ exclude_patterns = [
     '**/third-party/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -127,7 +127,7 @@ command = [
     '@{{PATHSFILE}}',
 ]
 init_command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -151,7 +151,7 @@ exclude_patterns = [
     '**/third-party/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -192,7 +192,7 @@ exclude_patterns = [
     'extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -234,7 +234,7 @@ exclude_patterns = [
     'util/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -287,7 +287,7 @@ exclude_patterns = [
     'util/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -337,7 +337,7 @@ exclude_patterns = [
     'backends/arm/test/**',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -349,7 +349,7 @@ command = [
     '@{{PATHSFILE}}'
 ]
 init_command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -368,7 +368,7 @@ exclude_patterns = [
     '.lintrunner.toml',
 ]
 command = [
-    'python3',
+    'python',
     '-m',
     'lintrunner_adapters',
     'run',
@@ -397,7 +397,7 @@ exclude_patterns = [
 ]

 command = [
-    "python3",
+    "python",
     "-m",
     "lintrunner_adapters",
     "run",

backends/cadence/aot/pass_utils.py

Lines changed: 37 additions & 14 deletions
@@ -174,30 +174,53 @@ def nodes_not_adjacent_in_gm(

 def get_arg(
     node: torch.fx.Node,
-    arg_index: int,
     kwarg_name: str,
-    *,
-    default: torch.fx.node.Argument = None,
 ) -> torch.fx.node.Argument:
     """
-    Get the arg at arg_index or kwarg with arg_name of the node. If neither is found
-    return default.
+    Get the arg with arg_name of the node, returns default value if not set.
     """
-    if arg_index < len(node.args):
-        return node.args[arg_index]
-    elif kwarg_name in node.kwargs:
+    # Try to get the arg from kwargs first since this is faster
+    if kwarg_name in node.kwargs:
         return node.kwargs[kwarg_name]
-    else:
-        return default
+
+    # If it's not found in kwargs, try to normalize the args
+    normalized_args = node.normalized_arguments(
+        node.graph.owning_module, normalize_to_only_use_kwargs=True
+    )
+    if not normalized_args:
+        raise RuntimeError(
+            f"get_arg: Node {node} does not support normalization of arguments"
+        )
+
+    return normalized_args.kwargs[kwarg_name]


 def set_arg(
-    node: torch.fx.Node, arg_index: int, kwarg_name: str, value: torch.fx.node.Argument
+    node: torch.fx.Node, kwarg_name: str, value: torch.fx.node.Argument
 ) -> None:
     """
-    Set the arg at arg_index if it exists, otherwise set the kwarg.
+    Set the node's arg with its name to the given value.
     """
-    if arg_index < len(node.args):
-        node.update_arg(arg_index, value)
+    # Try to set the arg if it is present in kwargs first since this is faster
+    if kwarg_name in node.kwargs:
+        node.update_kwarg(kwarg_name, value)
+        return
+
+    # If it's not found in kwargs, try to normalize the args and set the arg
+    normalized_args = node.normalized_arguments(
+        node.graph.owning_module, normalize_to_only_use_kwargs=True
+    )
+    if not normalized_args:
+        raise RuntimeError(
+            f"set_arg: Node {node} does not support normalization of arguments"
+        )
+
+    kwargs = normalized_args.kwargs
+    if kwarg_name not in kwargs:
+        raise ValueError(f"set_arg: invalid arg name {kwarg_name} for node {node} used")
+
+    idx = list(kwargs.keys()).index(kwarg_name)
+    if idx < len(node.args):
+        node.update_arg(idx, value)
     else:
         node.update_kwarg(kwarg_name, value)
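For reference, a minimal sketch of how the reworked helpers are meant to be called. This is not part of the diff: the AddModule wrapper is illustrative, and the import path is an assumption derived from the file's location.

import torch

# Assumed import path, based on backends/cadence/aot/pass_utils.py.
from executorch.backends.cadence.aot.pass_utils import get_arg, set_arg


class AddModule(torch.nn.Module):  # hypothetical example module
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.add(x, x, alpha=2)


gm = torch.fx.symbolic_trace(AddModule())
for node in gm.graph.nodes:
    if node.target is torch.add:
        # Arguments are now addressed by name only; positional indices are gone.
        alpha = get_arg(node, "alpha")  # fast path: read straight from kwargs
        set_arg(node, "alpha", 1)       # fast path: update the kwarg in place

The positional fallback now goes through Node.normalized_arguments, so callers no longer have to keep an arg index in sync with the op schema; nodes whose targets cannot be normalized raise RuntimeError instead of silently returning a default.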

backends/cadence/aot/remove_ops.py

Lines changed: 9 additions & 9 deletions
@@ -779,17 +779,17 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
         for slice_copy_node in graph_module.graph.find_nodes(
             op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
         ):
-            cat_node = cast(Node, get_arg(slice_copy_node, 0, "input"))
-            slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0))
-            start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None))
-            end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None))
-            step = cast(int, get_arg(slice_copy_node, 4, "step", default=1))
+            cat_node = cast(Node, get_arg(slice_copy_node, "input"))
+            slice_dim = cast(int, get_arg(slice_copy_node, "dim"))
+            start_idx = cast(int, get_arg(slice_copy_node, "start"))
+            end_idx = cast(int, get_arg(slice_copy_node, "end"))
+            step = cast(int, get_arg(slice_copy_node, "step"))

             if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
                 continue

             # Make sure cat and slice happens on the same dimension.
-            cat_dim = cast(Node, get_arg(cat_node, 1, "dim", default=0))
+            cat_dim = cast(Node, get_arg(cat_node, "dim"))
             if cat_dim != slice_dim:
                 continue

@@ -805,14 +805,14 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
                 end_idx += cat_output_shape[cat_dim]

             offset = 0
-            for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")):
+            for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
                 cat_input_shape = cat_input_node.meta["val"].shape

                 # Check if the slice range overlaps with the cat input range.
                 if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
                     slice_copy_node.replace_input_with(cat_node, cat_input_node)
-                    set_arg(slice_copy_node, 2, "start", start_idx - offset)
-                    set_arg(slice_copy_node, 3, "end", end_idx - offset)
+                    set_arg(slice_copy_node, "start", start_idx - offset)
+                    set_arg(slice_copy_node, "end", end_idx - offset)
                     break

                 offset += cat_input_shape[cat_dim]
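The rebasing arithmetic in this pass is easy to get wrong, so here is a standalone sketch of it (illustrative only, not from the diff): a slice over a cat output collapses to a slice over a single cat input exactly when the [start, end) window lies entirely within that input's extent along the cat dimension.

from typing import List, Optional, Tuple


def rebase_slice(sizes: List[int], start: int, end: int) -> Optional[Tuple[int, int, int]]:
    """Return (input_index, new_start, new_end) if one cat input covers the slice."""
    offset = 0
    for i, size in enumerate(sizes):
        # Slice window must sit fully inside this input's [offset, offset + size) range.
        if offset <= start and end <= offset + size:
            return i, start - offset, end - offset
        offset += size
    return None  # slice straddles an input boundary; the cat cannot be bypassed


# cat of extents [4, 6, 2] along the slice dim; slice [5, 9) falls inside input 1.
assert rebase_slice([4, 6, 2], 5, 9) == (1, 1, 5)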

backends/vulkan/op_registry.py

Lines changed: 5 additions & 0 deletions
@@ -259,6 +259,11 @@ def register_ephemeral_op(features: OpFeatures):
         exir_ops.edge.aten.div.Tensor,
         exir_ops.edge.aten.div.Tensor_mode,
         exir_ops.edge.aten.pow.Tensor_Tensor,
+        exir_ops.edge.aten.eq.Tensor,
+        exir_ops.edge.aten.lt.Tensor,
+        exir_ops.edge.aten.le.Tensor,
+        exir_ops.edge.aten.gt.Tensor,
+        exir_ops.edge.aten.ge.Tensor,
     ]
 )
 def register_binary_op(features: OpFeatures):
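With these registrations, graphs containing elementwise tensor comparisons become candidates for the Vulkan delegate. A minimal export sketch, assuming the standard ExecuTorch export flow and the usual VulkanPartitioner import path (the Compare module and tensor shapes are illustrative, not from the diff):

import torch

# Assumed import paths for the standard ExecuTorch Vulkan lowering flow.
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
from executorch.exir import to_edge_transform_and_lower


class Compare(torch.nn.Module):  # hypothetical example module
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        # These lower to aten.gt.Tensor / aten.le.Tensor, now registered above.
        return x.gt(y), x.le(y)


ep = torch.export.export(Compare(), (torch.randn(4), torch.randn(4)))
lowered = to_edge_transform_and_lower(ep, partitioner=[VulkanPartitioner()])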

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 8 additions & 0 deletions
@@ -728,9 +728,16 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
                 )

             for variant in params_dict["shader_variants"]:
+                default_iterated_params_names = set(
+                    default_iterated_params.keys()
+                    if default_iterated_params is not None
+                    else {}
+                )
                 variant_params_names = set(variant.keys())
+
                 invalid_keys = (
                     variant_params_names
+                    - default_iterated_params_names
                     - params_names
                     - {"generate_variant_forall"}
                 )
@@ -758,6 +765,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
                         variant_name = f"{variant_name}_{param_value[1]}"

                     default_params_copy["NAME"] = variant_name
+                    default_params_copy["VARIANT_NAME"] = variant["NAME"]

                     self.shader_template_params[template_name].append(
                         default_params_copy
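The effect of the widened validation is easier to see in isolation. A standalone sketch (not from the diff; the function name and example values are illustrative) of the same set arithmetic: a shader variant may now also override any parameter named in the template's default iterated (generate_variant_forall) parameters without tripping the invalid-key check.

from typing import Any, Dict, Optional


def invalid_variant_keys(
    variant: Dict[str, Any],
    params_names: set,
    default_iterated_params: Optional[Dict[str, Any]],
) -> set:
    # Names of parameters the template iterates over by default, if any.
    default_iterated_params_names = set(
        default_iterated_params.keys() if default_iterated_params is not None else {}
    )
    # Anything left over is neither a template param, an iterated param,
    # nor the reserved generate_variant_forall key, so it is invalid.
    return (
        set(variant.keys())
        - default_iterated_params_names
        - params_names
        - {"generate_variant_forall"}
    )


# A variant overriding an iterated param (e.g. DTYPE) no longer trips the check.
assert not invalid_variant_keys(
    {"NAME": "binary_eq", "DTYPE": "uint8"},
    params_names={"NAME", "OPERATOR"},
    default_iterated_params={"DTYPE": ["float", "half"]},
)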

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 21 additions & 3 deletions
@@ -10,17 +10,35 @@

 #define PRECISION ${PRECISION}

+// Binary comparison ops require that the output is boolean and not the same as input.
+$IS_COMPARISON_OP = (any([name in VARIANT_NAME for name in ["binary_eq", "binary_lt", "binary_le", "binary_gt", "binary_ge"]]))
+
+#define NAME ${VARIANT_NAME}
+
 #define VEC4_T ${texel_type(DTYPE)}
-#define T ${buffer_scalar_type(DTYPE)}
+$if IS_COMPARISON_OP:
+  #define T ${buffer_scalar_type("uint8")}
+  #define VEC4_OUT_T ${texel_type("uint8")}
+$else:
+  #define T ${buffer_scalar_type(DTYPE)}
+  #define VEC4_OUT_T VEC4_T

 #define op(X, Y, A) ${OPERATOR}

 ${define_active_storage_type(STORAGE)}
 ${define_required_extensions(DTYPE)}

+
+$if IS_COMPARISON_OP:
+  ${define_required_extensions("uint8")}
+
 layout(std430) buffer;

-${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
+$if IS_COMPARISON_OP:
+  ${layout_declare_tensor(B, "w", "t_out", "uint8", STORAGE)}
+$else:
+  ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
+
 ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
 ${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}

@@ -121,7 +139,7 @@ void main() {
   write_texel_lpos(
     t_out,
     lpos,
-    VEC4_T(op(in_texel, other_texel, alpha)),
+    VEC4_OUT_T(op(in_texel, other_texel, alpha)),
     out_axis_map);
 }
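The codegen-time predicate reads more clearly outside the template. A Python sketch (illustrative, not from the diff) that mirrors the $IS_COMPARISON_OP expression; VARIANT_NAME here stands for the per-variant NAME that gen_vulkan_spv.py now threads through:

def is_comparison_op(variant_name: str) -> bool:
    # A variant is a comparison op if its name contains any comparison stem.
    return any(
        name in variant_name
        for name in ["binary_eq", "binary_lt", "binary_le", "binary_gt", "binary_ge"]
    )


assert is_comparison_op("binary_eq_buffer")       # writes uint8 output texels
assert not is_comparison_op("binary_add_buffer")  # keeps DTYPE output texels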
