Skip to content

Commit 759b4e9

Browse files
author
ssjia
committed
[ET-VK][ez] Introduce a graph config setting to force resize functions to execute
Pull Request resolved: #15158 Title says it all! A few months back, a mechanism was introduced where an `ExecuteNode` would not call an operator's resize function if none of the arguments were updated. However, this creates a blind spot during testing where the resize function of operators are not tested since the generated operator tests do not modify input sizes. To address this, add a way to force the resize function to be called during testing. ghstack-source-id: 317069706 @exported-using-ghexport Differential Revision: [D84716451](https://our.internmc.facebook.com/intern/diff/D84716451/)
1 parent 26a9aec commit 759b4e9

File tree

4 files changed

+8
-2
lines changed

4 files changed

+8
-2
lines changed

backends/vulkan/runtime/graph/GraphConfig.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ GraphConfig::GraphConfig() {
6565
local_wg_size_override = {};
6666

6767
expect_dynamic_shapes = false;
68+
force_resize = false;
6869

6970
external_adapter = nullptr;
7071
}

backends/vulkan/runtime/graph/GraphConfig.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ struct GraphConfig final {
3535

3636
// Whether or not the ComputeGraph should expect input shapes to be dynamic
3737
bool expect_dynamic_shapes;
38+
// Used for testing/debugging only. Forces ExecuteNode to trigger the resize
39+
// function even if none of the inputs have been updated.
40+
bool force_resize = false;
3841

3942
// Execution properties that determine specifics re: how command buffer
4043
// submission is handled, etc. 0 means this field is not set.

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ ExecuteNode::ExecuteNode(
2121
name_(name) {}
2222

2323
bool ExecuteNode::trigger_resize(ComputeGraph* graph) {
24-
const bool any_arg_updated = was_any_arg_updated(graph);
25-
if (resize_fn_ && any_arg_updated) {
24+
bool any_arg_updated = was_any_arg_updated(graph);
25+
if (resize_fn_ && (any_arg_updated || graph->graphconfig().force_resize)) {
2626
resize_fn_(graph, args_, resize_args_);
27+
any_arg_updated = true;
2728
}
2829
return any_arg_updated;
2930
}

backends/vulkan/test/op_tests/utils/gen_correctness_vk.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple
3434
std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
3535
config.set_storage_type_override(default_storage_type);
3636
config.set_memory_layout_override(default_memory_layout);
37+
config.force_resize = true;
3738
graph = new ComputeGraph(config);
3839
3940
if (test_dtype == at::kHalf) {{

0 commit comments

Comments
 (0)