Skip to content

Commit a1d934f

Browse files
andreanicastrofacebook-github-bot
authored andcommitted
Binary Comparison Ops
Summary: This change introduces the `binary_eq`, `binary_lt`, `binary_le`, `binary_gt`, `binary_ge` operators. 1. Introduced the operators in the binary_op.yaml file 2. we now store the shader variant base name to better handle special cases in the glsl of binary op 3. binary ops assumed that the output of the operation is the same of the input data. For `binary_eq` this is not true. This is now handled both in the shader and the in the graph creation. 4. added test case Reviewed By: SS-JIA Differential Revision: D76049244
1 parent d533a87 commit a1d934f

File tree

8 files changed

+181
-5
lines changed

8 files changed

+181
-5
lines changed

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def register_ephemeral_op(features: OpFeatures):
259259
exir_ops.edge.aten.div.Tensor,
260260
exir_ops.edge.aten.div.Tensor_mode,
261261
exir_ops.edge.aten.pow.Tensor_Tensor,
262+
exir_ops.edge.aten.eq.Tensor,
262263
]
263264
)
264265
def register_binary_op(features: OpFeatures):

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,9 +728,19 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
728728
)
729729

730730
for variant in params_dict["shader_variants"]:
731+
default_iterated_params_names = set(
732+
default_iterated_params.keys()
733+
if default_iterated_params is not None
734+
else {}
735+
)
731736
variant_params_names = set(variant.keys())
737+
738+
print(
739+
f"default_iterated_params_names: {default_iterated_params_names}"
740+
)
732741
invalid_keys = (
733742
variant_params_names
743+
- default_iterated_params_names
734744
- params_names
735745
- {"generate_variant_forall"}
736746
)
@@ -758,6 +768,7 @@ def parseTemplateYaml(self, yaml_file: str) -> None:
758768
variant_name = f"{variant_name}_{param_value[1]}"
759769

760770
default_params_copy["NAME"] = variant_name
771+
default_params_copy["VARIANT_NAME"] = variant["NAME"]
761772

762773
self.shader_template_params[template_name].append(
763774
default_params_copy

backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,35 @@
1010

1111
#define PRECISION ${PRECISION}
1212

13+
// Binary comparison ops require that the output is boolean and not the same as input.
14+
$IS_COMPARISON_OP = ([name in VARIANT_NAME for name in ["binary_eq", "binary_lt", "binary_le", "binary_gt", "binary_ge"]])
15+
16+
#define NAME ${VARIANT_NAME}
17+
1318
#define VEC4_T ${texel_type(DTYPE)}
14-
#define T ${buffer_scalar_type(DTYPE)}
19+
$if IS_COMPARISON_OP:
20+
#define T ${buffer_scalar_type("uint8")}
21+
#define VEC4_OUT_T ${texel_type("uint8")}
22+
$else:
23+
#define T ${buffer_scalar_type(DTYPE)}
24+
#define VEC4_OUT_T VEC4_T
1525

1626
#define op(X, Y, A) ${OPERATOR}
1727

1828
${define_active_storage_type(STORAGE)}
1929
${define_required_extensions(DTYPE)}
2030

31+
32+
$if IS_COMPARISON_OP:
33+
${define_required_extensions("uint8")}
34+
2135
layout(std430) buffer;
2236

23-
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
37+
$if IS_COMPARISON_OP:
38+
${layout_declare_tensor(B, "w", "t_out", "uint8", STORAGE)}
39+
$else:
40+
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
41+
2442
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
2543
${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}
2644

@@ -121,7 +139,7 @@ void main() {
121139
write_texel_lpos(
122140
t_out,
123141
lpos,
124-
VEC4_T(op(in_texel, other_texel, alpha)),
142+
VEC4_OUT_T(op(in_texel, other_texel, alpha)),
125143
out_axis_map);
126144
}
127145

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,84 @@ binary_op:
3232
OPERATOR: floor(X / Y)
3333
- NAME: binary_minimum
3434
OPERATOR: min(X, Y)
35+
- NAME: binary_eq_int32
36+
OPERATOR: X == Y
37+
DTYPE: int32
38+
- NAME: binary_eq_buffer
39+
OPERATOR: abs(X - Y) < 1e-5
40+
STORAGE: buffer
41+
generate_variant_forall:
42+
DTYPE:
43+
- VALUE: half
44+
- VALUE: float
45+
- NAME: binary_eq_texture3d
46+
OPERATOR: all(lessThanEqual(abs(X - Y), VEC4_T(1e-5)))
47+
STORAGE: texture3d
48+
generate_variant_forall:
49+
DTYPE:
50+
- VALUE: half
51+
- VALUE: float
52+
- NAME: binary_lt_buffer
53+
OPERATOR: X < Y
54+
STORAGE: buffer
55+
generate_variant_forall:
56+
DTYPE:
57+
- VALUE: half
58+
- VALUE: float
59+
- VALUE: int32
60+
- NAME: binary_lt_texture3d
61+
OPERATOR: all(lessThan(X, Y))
62+
STORAGE: texture3d
63+
generate_variant_forall:
64+
DTYPE:
65+
- VALUE: half
66+
- VALUE: float
67+
- VALUE: int32
68+
- NAME: binary_le_buffer
69+
OPERATOR: X <= Y
70+
STORAGE: buffer
71+
generate_variant_forall:
72+
DTYPE:
73+
- VALUE: half
74+
- VALUE: float
75+
- VALUE: int32
76+
- NAME: binary_le_texture3d
77+
OPERATOR: all(lessThanEqual(X, Y))
78+
STORAGE: texture3d
79+
generate_variant_forall:
80+
DTYPE:
81+
- VALUE: half
82+
- VALUE: float
83+
- VALUE: int32
84+
- NAME: binary_gt_buffer
85+
OPERATOR: X > Y
86+
STORAGE: buffer
87+
generate_variant_forall:
88+
DTYPE:
89+
- VALUE: half
90+
- VALUE: float
91+
- VALUE: int32
92+
- NAME: binary_gt_texture3d
93+
OPERATOR: all(greaterThan(X, Y))
94+
STORAGE: texture3d
95+
generate_variant_forall:
96+
DTYPE:
97+
- VALUE: half
98+
- VALUE: float
99+
- VALUE: int32
100+
- NAME: binary_ge_buffer
101+
OPERATOR: X >= Y
102+
STORAGE: buffer
103+
generate_variant_forall:
104+
DTYPE:
105+
- VALUE: half
106+
- VALUE: float
107+
- VALUE: int32
108+
- NAME: binary_ge_texture3d
109+
OPERATOR: all(greaterThanEqual(X, Y))
110+
STORAGE: texture3d
111+
generate_variant_forall:
112+
DTYPE:
113+
- VALUE: half
114+
- VALUE: float
115+
- VALUE: int32

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ void add_binary_op_texture_node(
7777
kernel_name.reserve(kShaderNameReserve);
7878
kernel_name += op_name;
7979
add_storage_type_suffix(kernel_name, *t_out);
80-
add_dtype_suffix(kernel_name, *t_out);
80+
add_dtype_suffix(kernel_name, graph.dtype_of(in1));
8181

8282
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
8383
graph,
@@ -121,7 +121,8 @@ void add_binary_op_buffer_node(
121121
kernel_name.reserve(kShaderNameReserve);
122122
kernel_name += op_name;
123123
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
124-
add_dtype_suffix(kernel_name, graph.dtype_of(out));
124+
125+
add_dtype_suffix(kernel_name, graph.dtype_of(in1));
125126

126127
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
127128
graph,
@@ -189,6 +190,11 @@ DEFINE_BINARY_OP_FN(mul);
189190
DEFINE_BINARY_OP_FN(div);
190191
DEFINE_BINARY_OP_FN(pow);
191192
DEFINE_BINARY_OP_FN(minimum);
193+
DEFINE_BINARY_OP_FN(eq);
194+
DEFINE_BINARY_OP_FN(lt);
195+
DEFINE_BINARY_OP_FN(le);
196+
DEFINE_BINARY_OP_FN(gt);
197+
DEFINE_BINARY_OP_FN(ge);
192198

193199
REGISTER_OPERATORS {
194200
VK_REGISTER_OP(aten.add.Tensor, add);
@@ -198,6 +204,11 @@ REGISTER_OPERATORS {
198204
VK_REGISTER_OP(aten.div.Tensor_mode, floor_divide);
199205
VK_REGISTER_OP(aten.pow.Tensor_Tensor, pow);
200206
VK_REGISTER_OP(aten.minimum.default, minimum);
207+
VK_REGISTER_OP(aten.eq.Tensor, eq);
208+
VK_REGISTER_OP(aten.lt.Tensor, lt);
209+
VK_REGISTER_OP(aten.le.Tensor, le);
210+
VK_REGISTER_OP(aten.gt.Tensor, gt);
211+
VK_REGISTER_OP(aten.ge.Tensor, ge);
201212
}
202213

203214
} // namespace vkcompute

backends/vulkan/test/op_tests/cases.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,42 @@ def get_binary_elementwise_inputs():
6363
"utils::kBuffer",
6464
"utils::kTexture3D",
6565
]
66+
67+
return test_suite
68+
69+
70+
# Eq requires a different test generator so it was split from the other test case.
71+
@register_test_suite(
72+
[
73+
"aten.eq.Tensor",
74+
"aten.gt.Tensor",
75+
"aten.lt.Tensor",
76+
"aten.ge.Tensor",
77+
"aten.le.Tensor",
78+
]
79+
)
80+
def get_binary_elementwise_compare_inputs():
81+
test_suite = VkTestSuite(
82+
[
83+
((M1, M2), (M1, M2)),
84+
((M1, M2), (M1, 1), 2.0),
85+
((M1, M2), (1, M2)),
86+
((S, S1, S2), (S, S1, S2)),
87+
((S, S1, S2), (S, S1, 1), 2.0),
88+
((S, S1, S2), (S, 1, S2), 2.0),
89+
((XS, S, S1, S2), (XS, S, 1, 1), 2.0),
90+
((3, 64, 1), (1, 64, 1)),
91+
]
92+
)
93+
test_suite.layouts = [
94+
"utils::kWidthPacked",
95+
"utils::kChannelsPacked",
96+
]
97+
test_suite.storage_types = [
98+
"utils::kBuffer",
99+
"utils::kTexture3D",
100+
]
101+
test_suite.data_gen = "make_casted_randint_tensor"
66102
return test_suite
67103

68104

backends/vulkan/test/op_tests/utils/gen_benchmark_vk.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,15 @@ def generate_benchmark_fixture(self) -> str:
196196
}}
197197
}}
198198
199+
at::Tensor make_casted_randint_tensor(
200+
std::vector<int64_t> sizes,
201+
at::ScalarType dtype = at::kFloat,
202+
int low = 0,
203+
int high = 10) {{
204+
205+
return at::randint(high, sizes, at::device(at::kCPU).dtype(dtype));
206+
}}
207+
199208
at::Tensor make_rand_tensor(
200209
std::vector<int64_t> sizes,
201210
at::ScalarType dtype = at::kFloat,

backends/vulkan/test/op_tests/utils/gen_correctness_base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,15 @@ def generate_suite_cpp(self) -> str:
283283
284284
{preamble}
285285
286+
at::Tensor make_casted_randint_tensor(
287+
std::vector<int64_t> sizes,
288+
at::ScalarType dtype = at::kFloat,
289+
int low = 0,
290+
int high = 10) {{
291+
292+
return at::randint(high, sizes, at::device(at::kCPU).dtype(dtype));
293+
}}
294+
286295
at::Tensor make_rand_tensor(
287296
std::vector<int64_t> sizes,
288297
at::ScalarType dtype = at::kFloat,

0 commit comments

Comments
 (0)