Skip to content

Commit 887f82d

Browse files
pytorchbot and hinriksnaer
authored and committed
[ET-VK][testing] Improvement to operator test codegen system (pytorch#11972)
## Changes

* Allow test cases to specify storage types / memory layouts for individual args
* Allow test cases to specify different data generation functions for individual args

## Motivation

> Allow test cases to specify storage types / memory layouts for individual args

Makes it possible to test args that require specific storage types for certain input/output tensors.

> Allow test cases to specify different data generation functions for individual args

Useful for debugging operators during development.

Differential Revision: [D77038777](https://our.internmc.facebook.com/intern/diff/D77038777/)
1 parent c43ffb9 commit 887f82d

File tree

4 files changed

+96
-7
lines changed

4 files changed

+96
-7
lines changed

backends/vulkan/test/op_tests/utils/gen_computegraph.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class ValueRef:
5858
src_cpp_type: str
5959
is_in: bool = False
6060
is_out: bool = False
61+
fixed_storage_type: Optional[str] = None
62+
fixed_memory_layout: Optional[str] = None
6163
requires_prepack: bool = False
6264
supports_prepack: bool = False
6365
# When is_dynamic_size is true, the underlying object size is not known
@@ -137,20 +139,43 @@ def __init__(
137139
if arg.name in self.suite_def.prepacked_args:
138140
supports_prepack = True
139141

142+
fixed_storage_type = None
143+
if arg.name in self.suite_def.arg_storage_types:
144+
fixed_storage_type = self.suite_def.arg_storage_types[arg.name]
145+
146+
fixed_memory_layout = None
147+
if arg.name in self.suite_def.arg_memory_layouts:
148+
fixed_memory_layout = self.suite_def.arg_memory_layouts[arg.name]
149+
140150
self.refs[arg.name] = ValueRef(
141151
name=f"{arg.name}_ref",
142152
src_cpp_name=arg.name,
143153
src_cpp_type=cpp_type,
144154
is_in=(cpp_type in InableCppType),
155+
fixed_storage_type=fixed_storage_type,
156+
fixed_memory_layout=fixed_memory_layout,
145157
requires_prepack=requires_prepack,
146158
supports_prepack=supports_prepack,
147159
)
148160

149161
ret_type = cpp.returns_type(self.f.func.returns, symint=False).cpp_type()
150162
self.out = ATenArg(name="out", cpp_type=ret_type, default=None)
163+
164+
fixed_storage_type = None
165+
if "out" in self.suite_def.arg_storage_types:
166+
fixed_storage_type = self.suite_def.arg_storage_types["out"]
167+
fixed_memory_layout = None
168+
if "out" in self.suite_def.arg_memory_layouts:
169+
fixed_memory_layout = self.suite_def.arg_memory_layouts["out"]
170+
151171
if ret_type == AT_TENSOR:
152172
self.refs["out"] = ValueRef(
153-
name="out_ref", src_cpp_name="out", src_cpp_type=ret_type, is_out=True
173+
name="out_ref",
174+
src_cpp_name="out",
175+
src_cpp_type=ret_type,
176+
is_out=True,
177+
fixed_storage_type=fixed_storage_type,
178+
fixed_memory_layout=fixed_memory_layout,
154179
)
155180
elif ret_type == TWO_TENSOR_TUPLE:
156181
self.refs["out"] = [
@@ -159,12 +184,24 @@ def __init__(
159184
src_cpp_name="std::get<0>(out)",
160185
src_cpp_type="at::Tensor",
161186
is_out=True,
187+
fixed_storage_type=(
188+
fixed_storage_type[0] if fixed_storage_type else None
189+
),
190+
fixed_memory_layout=(
191+
fixed_memory_layout[0] if fixed_memory_layout else None
192+
),
162193
),
163194
ValueRef(
164195
name="out_ref_second",
165196
src_cpp_name="std::get<1>(out)",
166197
src_cpp_type="at::Tensor",
167198
is_out=True,
199+
fixed_storage_type=(
200+
fixed_storage_type[1] if fixed_storage_type else None
201+
),
202+
fixed_memory_layout=(
203+
fixed_memory_layout[1] if fixed_memory_layout else None
204+
),
168205
),
169206
ValueRef(
170207
name="out_ref",
@@ -180,18 +217,36 @@ def __init__(
180217
src_cpp_name="std::get<0>(out)",
181218
src_cpp_type="at::Tensor",
182219
is_out=True,
220+
fixed_storage_type=(
221+
fixed_storage_type[0] if fixed_storage_type else None
222+
),
223+
fixed_memory_layout=(
224+
fixed_memory_layout[0] if fixed_memory_layout else None
225+
),
183226
),
184227
ValueRef(
185228
name="out_ref_second",
186229
src_cpp_name="std::get<1>(out)",
187230
src_cpp_type="at::Tensor",
188231
is_out=True,
232+
fixed_storage_type=(
233+
fixed_storage_type[1] if fixed_storage_type else None
234+
),
235+
fixed_memory_layout=(
236+
fixed_memory_layout[1] if fixed_memory_layout else None
237+
),
189238
),
190239
ValueRef(
191240
name="out_ref_third",
192241
src_cpp_name="std::get<2>(out)",
193242
src_cpp_type="at::Tensor",
194243
is_out=True,
244+
fixed_storage_type=(
245+
fixed_storage_type[2] if fixed_storage_type else None
246+
),
247+
fixed_memory_layout=(
248+
fixed_memory_layout[2] if fixed_memory_layout else None
249+
),
195250
),
196251
ValueRef(
197252
name="out_ref",
@@ -302,7 +357,12 @@ def create_value_for( # noqa: C901
302357
ret_str += f"{self.graph}{self.dot}"
303358
ret_str += "add_input_tensor(" if ref.is_in else "add_tensor("
304359
ret_str += f"{ref.src_cpp_name}->sizes().vec(), "
305-
ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type())); \n"
360+
ret_str += f"from_at_scalartype({ref.src_cpp_name}->scalar_type()"
361+
if ref.fixed_storage_type:
362+
ret_str += f", {ref.fixed_storage_type}"
363+
if ref.fixed_memory_layout:
364+
ret_str += f", {ref.fixed_memory_layout}"
365+
ret_str += "));\n"
306366
elif prepack:
307367
ret_str += f"{self.graph}{self.dot}"
308368
ret_str += f"add_tensorref({ref.src_cpp_name}->sizes().vec(), "
@@ -385,7 +445,12 @@ def create_value_for( # noqa: C901
385445
elif ref.src_cpp_type == AT_TENSOR and not prepack:
386446
ret_str += "add_input_tensor(" if ref.is_in else "add_tensor("
387447
ret_str += f"{ref.src_cpp_name}.sizes().vec(), "
388-
ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())); \n"
448+
ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type())"
449+
if ref.fixed_storage_type:
450+
ret_str += f", {ref.fixed_storage_type}"
451+
if ref.fixed_memory_layout:
452+
ret_str += f", {ref.fixed_memory_layout}"
453+
ret_str += ");\n"
389454
elif ref.src_cpp_type == AT_TENSOR and prepack:
390455
ret_str += f"add_tensorref({ref.src_cpp_name}.sizes().vec(), "
391456
ret_str += f"from_at_scalartype({ref.src_cpp_name}.scalar_type()), "

backends/vulkan/test/op_tests/utils/gen_correctness_base.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,13 @@ def call_data_gen_fn(self, arg: Argument, data: Any, terminate: bool = True) ->
140140
else self.suite_def.arg_data_range[arg.name]
141141
)
142142

143-
ret_str = f"{self.suite_def.data_gen}({init_list_str(data)}, {tensor_dtype}, {data_range[0]}, {data_range[1]})"
143+
data_gen_fn = (
144+
self.suite_def.data_gen
145+
if arg.name not in self.suite_def.arg_data_gen_fn
146+
else self.suite_def.arg_data_gen_fn[arg.name]
147+
)
148+
149+
ret_str = f"{data_gen_fn}({init_list_str(data)}, {tensor_dtype}, {data_range[0]}, {data_range[1]})"
144150
if terminate:
145151
ret_str += ";"
146152

@@ -288,13 +294,29 @@ def generate_suite_cpp(self) -> str:
288294
289295
if (dtype == at::kBool)
290296
return at::rand(sizes, at::device(at::kCPU)) > 0.5;
291-
297+
292298
if (high == 1.0 && low == 0.0)
293299
return at::rand(sizes, at::device(at::kCPU).dtype(dtype));
294300
295301
return at::rand(sizes, at::device(at::kCPU).dtype(dtype)) * (high - low) + low;
296302
}}
297303
304+
at::Tensor make_zeros_tensor(
305+
std::vector<int64_t> sizes,
306+
at::ScalarType dtype = at::kFloat,
307+
float low = 0.0,
308+
float high = 1.0) {{
309+
return at::zeros(sizes, at::device(at::kCPU).dtype(dtype));
310+
}}
311+
312+
at::Tensor make_ones_tensor(
313+
std::vector<int64_t> sizes,
314+
at::ScalarType dtype = at::kFloat,
315+
float low = 0.0,
316+
float high = 1.0) {{
317+
return at::ones(sizes, at::device(at::kCPU).dtype(dtype));
318+
}}
319+
298320
at::Tensor make_seq_tensor(
299321
std::vector<int64_t> sizes,
300322
at::ScalarType dtype = at::kFloat,

backends/vulkan/test/op_tests/utils/gen_correctness_vk.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple
2929
3030
void SetUp() override {{
3131
GraphConfig config;
32-
config.expect_dynamic_shapes = true;
3332
utils::StorageType default_storage_type;
3433
utils::GPUMemoryLayout default_memory_layout;
3534
std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();

backends/vulkan/test/op_tests/utils/test_suite.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from dataclasses import dataclass
8-
from typing import Any, List, Optional
8+
from typing import Any, Dict, List, Optional
99

1010
###################################
1111
## Generic Test Suite definition ##
@@ -23,6 +23,7 @@ def __init__(self, input_cases: List[Any]):
2323
self.data_range = (0, 1)
2424

2525
self.arg_dtype = {}
26+
self.arg_data_gen_fn: Dict[str, str] = {}
2627
self.arg_data_range = {}
2728

2829
self.atol: str = "1e-5"
@@ -48,3 +49,5 @@ def __init__(self, input_cases: List[Any]):
4849
self.layouts: List[str] = ["utils::kChannelsPacked"]
4950
self.data_gen: str = "make_rand_tensor"
5051
self.force_io: bool = True
52+
self.arg_storage_types: Dict[str, str] = {}
53+
self.arg_memory_layouts: Dict[str, str] = {}

0 commit comments

Comments (0)