Skip to content

Commit bfc5b17

Browse files
authored
[ET-VK][ez] Enable no-op ExecuteNodes for view ops
Differential Revision: D61666465 Pull Request resolved: #4843
1 parent 87b38cf commit bfc5b17

File tree

3 files changed

+43
-1
lines changed

3 files changed

+43
-1
lines changed

backends/vulkan/runtime/graph/ops/ExecuteNode.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,22 @@ ExecuteNode::ExecuteNode(
3535
graph.update_descriptor_counts(shader, /*execute = */ true);
3636
}
3737

38+
ExecuteNode::ExecuteNode(
39+
const ResizeFunction& resize_fn,
40+
const std::vector<ValueRef>& resize_args)
41+
: shader_(),
42+
global_workgroup_size_({0u, 0u, 0u}),
43+
local_workgroup_size_({0u, 0u, 0u}),
44+
args_(),
45+
params_(),
46+
spec_vars_(),
47+
resize_fn_(resize_fn),
48+
resize_args_(resize_args) {}
49+
3850
void ExecuteNode::encode(ComputeGraph* graph) {
51+
if (!shader_) {
52+
return;
53+
}
3954
api::Context* const context = graph->context();
4055
vkapi::PipelineBarrier pipeline_barrier{};
4156

backends/vulkan/runtime/graph/ops/ExecuteNode.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class ExecuteNode final {
4848
const std::vector<ArgGroup>&,
4949
const std::vector<ValueRef>&)>;
5050

51-
ExecuteNode(
51+
explicit ExecuteNode(
5252
ComputeGraph& graph,
5353
const vkapi::ShaderInfo& shader,
5454
const utils::uvec3& global_workgroup_size,
@@ -59,6 +59,15 @@ class ExecuteNode final {
5959
const ResizeFunction& resize_fn = nullptr,
6060
const std::vector<ValueRef>& resize_args = {});
6161

62+
/*
63+
* This overload of the ExecuteNode constructor is used to register ops which
64+
* update a tensor view. No shader is dispatched, but the node still needs to
65+
* update the view's sizes and strides after a resize.
66+
*/
67+
explicit ExecuteNode(
68+
const ResizeFunction& resize_fn = nullptr,
69+
const std::vector<ValueRef>& resize_args = {});
70+
6271
~ExecuteNode() = default;
6372

6473
void encode(ComputeGraph* graph);
@@ -83,6 +92,11 @@ class ExecuteNode final {
8392
const vkapi::SpecVarList spec_vars_;
8493
const ResizeFunction resize_fn_;
8594
const std::vector<ValueRef> resize_args_;
95+
96+
public:
97+
operator bool() const {
98+
return shader_;
99+
}
86100
};
87101

88102
} // namespace vkcompute

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -982,6 +982,19 @@ TEST(VulkanComputeGraphTest, test_values_string) {
982982
EXPECT_TRUE(stored == "hello, world");
983983
}
984984

985+
TEST(VulkanComputeGraphTest, empty_init_executenode_test) {
986+
ExecuteNode node(nullptr, {});
987+
EXPECT_FALSE(node);
988+
989+
GraphConfig config;
990+
ComputeGraph graph(config);
991+
992+
// Encode an empty ExecuteNode and check that command buffer encoding does not
993+
// crash.
994+
graph.execute_nodes().emplace_back(new ExecuteNode(nullptr, {}));
995+
EXPECT_NO_FATAL_FAILURE(graph.encode_execute());
996+
}
997+
985998
TEST(VulkanComputeGraphTest, test_zero_dim_tensor) {
986999
GraphConfig config;
9871000
ComputeGraph graph(config);

0 commit comments

Comments
 (0)