Skip to content

Commit 48ea65b

Browse files
committed
Update on "[ET-VK] Adding get or create int function to read int value."
This diff adds a new function `get_or_create_int` to the `ComputeGraph` class, which allows reading an integer value from a `ValueRef` index. The function returns the extracted integer value if the value at the index is an integer; otherwise it throws an error. Additionally, an overload of the function is added to return a default value if the value at the index is `None`.

Differential Revision: [D78094858](https://our.internmc.facebook.com/intern/diff/D78094858/)

[ghstack-poisoned]
2 parents 86aca86 + e8c538f commit 48ea65b

26 files changed

+553
-145
lines changed

.lintrunner.toml

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ exclude_patterns = [
1010
'exir/serde/**',
1111
]
1212
command = [
13-
'python3',
13+
'python',
1414
'-m',
1515
'lintrunner_adapters',
1616
'run',
@@ -19,7 +19,7 @@ command = [
1919
'@{{PATHSFILE}}'
2020
]
2121
init_command = [
22-
'python3',
22+
'python',
2323
'-m',
2424
'lintrunner_adapters',
2525
'run',
@@ -41,7 +41,7 @@ exclude_patterns = [
4141
'exir/serde/**',
4242
]
4343
command = [
44-
'python3',
44+
'python',
4545
'-m',
4646
'lintrunner_adapters',
4747
'run',
@@ -50,7 +50,7 @@ command = [
5050
'@{{PATHSFILE}}'
5151
]
5252
init_command = [
53-
'python3',
53+
'python',
5454
'-m',
5555
'lintrunner_adapters',
5656
'run',
@@ -84,7 +84,7 @@ exclude_patterns = [
8484
'runtime/core/portable_type/c10/**',
8585
]
8686
command = [
87-
'python3',
87+
'python',
8888
'-m',
8989
'lintrunner_adapters',
9090
'run',
@@ -95,7 +95,7 @@ command = [
9595
'@{{PATHSFILE}}'
9696
]
9797
init_command = [
98-
'python3',
98+
'python',
9999
'-m',
100100
'lintrunner_adapters',
101101
'run',
@@ -117,7 +117,7 @@ exclude_patterns = [
117117
'**/third-party/**',
118118
]
119119
command = [
120-
'python3',
120+
'python',
121121
'-m',
122122
'lintrunner_adapters',
123123
'run',
@@ -127,7 +127,7 @@ command = [
127127
'@{{PATHSFILE}}',
128128
]
129129
init_command = [
130-
'python3',
130+
'python',
131131
'-m',
132132
'lintrunner_adapters',
133133
'run',
@@ -151,7 +151,7 @@ exclude_patterns = [
151151
'**/third-party/**',
152152
]
153153
command = [
154-
'python3',
154+
'python',
155155
'-m',
156156
'lintrunner_adapters',
157157
'run',
@@ -192,7 +192,7 @@ exclude_patterns = [
192192
'extension/llm/custom_ops/spinquant/test/fast_hadamard_transform_special_unstrided_cpu.h',
193193
]
194194
command = [
195-
'python3',
195+
'python',
196196
'-m',
197197
'lintrunner_adapters',
198198
'run',
@@ -234,7 +234,7 @@ exclude_patterns = [
234234
'util/**',
235235
]
236236
command = [
237-
'python3',
237+
'python',
238238
'-m',
239239
'lintrunner_adapters',
240240
'run',
@@ -287,7 +287,7 @@ exclude_patterns = [
287287
'util/**',
288288
]
289289
command = [
290-
'python3',
290+
'python',
291291
'-m',
292292
'lintrunner_adapters',
293293
'run',
@@ -337,7 +337,7 @@ exclude_patterns = [
337337
'backends/arm/test/**',
338338
]
339339
command = [
340-
'python3',
340+
'python',
341341
'-m',
342342
'lintrunner_adapters',
343343
'run',
@@ -349,7 +349,7 @@ command = [
349349
'@{{PATHSFILE}}'
350350
]
351351
init_command = [
352-
'python3',
352+
'python',
353353
'-m',
354354
'lintrunner_adapters',
355355
'run',
@@ -368,7 +368,7 @@ exclude_patterns = [
368368
'.lintrunner.toml',
369369
]
370370
command = [
371-
'python3',
371+
'python',
372372
'-m',
373373
'lintrunner_adapters',
374374
'run',
@@ -397,7 +397,7 @@ exclude_patterns = [
397397
]
398398

399399
command = [
400-
"python3",
400+
"python",
401401
"-m",
402402
"lintrunner_adapters",
403403
"run",

backends/cadence/aot/pass_utils.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -174,30 +174,53 @@ def nodes_not_adjacent_in_gm(
174174

175175
def get_arg(
176176
node: torch.fx.Node,
177-
arg_index: int,
178177
kwarg_name: str,
179-
*,
180-
default: torch.fx.node.Argument = None,
181178
) -> torch.fx.node.Argument:
182179
"""
183-
Get the arg at arg_index or kwarg with arg_name of the node. If neither is found
184-
return default.
180+
Get the arg with arg_name of the node, returns default value if not set.
185181
"""
186-
if arg_index < len(node.args):
187-
return node.args[arg_index]
188-
elif kwarg_name in node.kwargs:
182+
# Try to get the arg from kwargs first since this is faster
183+
if kwarg_name in node.kwargs:
189184
return node.kwargs[kwarg_name]
190-
else:
191-
return default
185+
186+
# If it's not found in kwargs, try to normalize the args
187+
normalized_args = node.normalized_arguments(
188+
node.graph.owning_module, normalize_to_only_use_kwargs=True
189+
)
190+
if not normalized_args:
191+
raise RuntimeError(
192+
f"get_arg: Node {node} does not support normalization of arguments"
193+
)
194+
195+
return normalized_args.kwargs[kwarg_name]
192196

193197

194198
def set_arg(
195-
node: torch.fx.Node, arg_index: int, kwarg_name: str, value: torch.fx.node.Argument
199+
node: torch.fx.Node, kwarg_name: str, value: torch.fx.node.Argument
196200
) -> None:
197201
"""
198-
Set the arg at arg_index if it exists, otherwise set the kwarg.
202+
Set the node's arg with its name to the given value.
199203
"""
200-
if arg_index < len(node.args):
201-
node.update_arg(arg_index, value)
204+
# Try to set the arg if it is present in kwargs first since this is faster
205+
if kwarg_name in node.kwargs:
206+
node.update_kwarg(kwarg_name, value)
207+
return
208+
209+
# If it's not found in kwargs, try to normalize the args and set the arg
210+
normalized_args = node.normalized_arguments(
211+
node.graph.owning_module, normalize_to_only_use_kwargs=True
212+
)
213+
if not normalized_args:
214+
raise RuntimeError(
215+
f"set_arg: Node {node} does not support normalization of arguments"
216+
)
217+
218+
kwargs = normalized_args.kwargs
219+
if kwarg_name not in kwargs:
220+
raise ValueError(f"set_arg: invalid arg name {kwarg_name} for node {node} used")
221+
222+
idx = list(kwargs.keys()).index(kwarg_name)
223+
if idx < len(node.args):
224+
node.update_arg(idx, value)
202225
else:
203226
node.update_kwarg(kwarg_name, value)

backends/cadence/aot/remove_ops.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -779,17 +779,17 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
779779
for slice_copy_node in graph_module.graph.find_nodes(
780780
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
781781
):
782-
cat_node = cast(Node, get_arg(slice_copy_node, 0, "input"))
783-
slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0))
784-
start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None))
785-
end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None))
786-
step = cast(int, get_arg(slice_copy_node, 4, "step", default=1))
782+
cat_node = cast(Node, get_arg(slice_copy_node, "input"))
783+
slice_dim = cast(int, get_arg(slice_copy_node, "dim"))
784+
start_idx = cast(int, get_arg(slice_copy_node, "start"))
785+
end_idx = cast(int, get_arg(slice_copy_node, "end"))
786+
step = cast(int, get_arg(slice_copy_node, "step"))
787787

788788
if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
789789
continue
790790

791791
# Make sure cat and slice happens on the same dimension.
792-
cat_dim = cast(Node, get_arg(cat_node, 1, "dim", default=0))
792+
cat_dim = cast(Node, get_arg(cat_node, "dim"))
793793
if cat_dim != slice_dim:
794794
continue
795795

@@ -805,14 +805,14 @@ def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
805805
end_idx += cat_output_shape[cat_dim]
806806

807807
offset = 0
808-
for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")):
808+
for cat_input_node in cast(List[Node], get_arg(cat_node, "tensors")):
809809
cat_input_shape = cat_input_node.meta["val"].shape
810810

811811
# Check if the slice range overlaps with the cat input range.
812812
if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
813813
slice_copy_node.replace_input_with(cat_node, cat_input_node)
814-
set_arg(slice_copy_node, 2, "start", start_idx - offset)
815-
set_arg(slice_copy_node, 3, "end", end_idx - offset)
814+
set_arg(slice_copy_node, "start", start_idx - offset)
815+
set_arg(slice_copy_node, "end", end_idx - offset)
816816
break
817817

818818
offset += cat_input_shape[cat_dim]

backends/vulkan/op_registry.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,11 @@ def register_ephemeral_op(features: OpFeatures):
259259
exir_ops.edge.aten.div.Tensor,
260260
exir_ops.edge.aten.div.Tensor_mode,
261261
exir_ops.edge.aten.pow.Tensor_Tensor,
262+
exir_ops.edge.aten.eq.Tensor,
263+
exir_ops.edge.aten.lt.Tensor,
264+
exir_ops.edge.aten.le.Tensor,
265+
exir_ops.edge.aten.gt.Tensor,
266+
exir_ops.edge.aten.ge.Tensor,
262267
]
263268
)
264269
def register_binary_op(features: OpFeatures):

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,9 +728,16 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
728728
)
729729

730730
for variant in params_dict["shader_variants"]:
731+
default_iterated_params_names = set(
732+
default_iterated_params.keys()
733+
if default_iterated_params is not None
734+
else {}
735+
)
731736
variant_params_names = set(variant.keys())
737+
732738
invalid_keys = (
733739
variant_params_names
740+
- default_iterated_params_names
734741
- params_names
735742
- {"generate_variant_forall"}
736743
)
@@ -758,6 +765,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
758765
variant_name = f"{variant_name}_{param_value[1]}"
759766

760767
default_params_copy["NAME"] = variant_name
768+
default_params_copy["VARIANT_NAME"] = variant["NAME"]
761769

762770
self.shader_template_params[template_name].append(
763771
default_params_copy

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -549,20 +549,13 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
549549
}
550550
}
551551

552-
int32_t ComputeGraph::get_or_create_int(const ValueRef idx) {
553-
if (values_.at(idx).isInt()) {
554-
return extract_scalar<int32_t>(idx);
555-
}
556-
VK_THROW("Cannot create a int param buffer for the given value");
557-
}
558-
559552
int32_t ComputeGraph::get_or_create_int(
560553
const ValueRef idx,
561554
const int32_t default_val) {
562555
if (values_.at(idx).isNone()) {
563556
return default_val;
564557
}
565-
return get_or_create_int(idx);
558+
return extract_scalar<int32_t>(idx);
566559
}
567560

568561
void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -685,8 +685,6 @@ class ComputeGraph final {
685685
const ValueRef idx,
686686
const int32_t default_value);
687687

688-
int32_t get_or_create_int(const ValueRef idx);
689-
690688
int32_t get_or_create_int(const ValueRef idx, const int32_t default_value);
691689

692690
void set_symint(const ValueRef idx, const int32_t val);

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,35 @@
1010

1111
#define PRECISION ${PRECISION}
1212

13+
// Binary comparison ops require that the output is boolean and not the same as input.
14+
$IS_COMPARISON_OP = (any([name in VARIANT_NAME for name in ["binary_eq", "binary_lt", "binary_le", "binary_gt", "binary_ge"]]))
15+
16+
#define NAME ${VARIANT_NAME}
17+
1318
#define VEC4_T ${texel_type(DTYPE)}
14-
#define T ${buffer_scalar_type(DTYPE)}
19+
$if IS_COMPARISON_OP:
20+
#define T ${buffer_scalar_type("uint8")}
21+
#define VEC4_OUT_T ${texel_type("uint8")}
22+
$else:
23+
#define T ${buffer_scalar_type(DTYPE)}
24+
#define VEC4_OUT_T VEC4_T
1525

1626
#define op(X, Y, A) ${OPERATOR}
1727

1828
${define_active_storage_type(STORAGE)}
1929
${define_required_extensions(DTYPE)}
2030

31+
32+
$if IS_COMPARISON_OP:
33+
${define_required_extensions("uint8")}
34+
2135
layout(std430) buffer;
2236

23-
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
37+
$if IS_COMPARISON_OP:
38+
${layout_declare_tensor(B, "w", "t_out", "uint8", STORAGE)}
39+
$else:
40+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
41+
2442
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
2543
${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}
2644

@@ -121,7 +139,7 @@ void main() {
121139
write_texel_lpos(
122140
t_out,
123141
lpos,
124-
VEC4_T(op(in_texel, other_texel, alpha)),
142+
VEC4_OUT_T(op(in_texel, other_texel, alpha)),
125143
out_axis_map);
126144
}
127145

0 commit comments

Comments (0)