Skip to content

Commit 98c57a6

Browse files
committed
Update on "[ET-VK] Implement prepack nodes"
## Context This diff implements the idea described in the previous diff in this stack. During export, `et_vk.prepack` nodes will be inserted to convert constant tensors to GPU tensor objects. This makes it so that Vulkan operators will not have to account for the possibility that their arguments can potentially be constant tensor data instead of an actual tensor object. Differential Revision: [D64603666](https://our.internmc.facebook.com/intern/diff/D64603666/) [ghstack-poisoned]
2 parents 499f8f2 + 28c2cf6 commit 98c57a6

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,11 @@ void add_binary_op_node(
5151
const ValueRef alpha,
5252
const ValueRef out,
5353
const std::string& op_name) {
54-
vTensorPtr t_in1 = graph.get_tensor(in1);
55-
vTensorPtr t_in2 = graph.get_tensor(in2);
54+
ValueRef arg1 = prepack_standard_like(graph, in1, out, true);
55+
ValueRef arg2 = prepack_standard_like(graph, in2, out, true);
56+
57+
vTensorPtr t_in1 = graph.get_tensor(arg1);
58+
vTensorPtr t_in2 = graph.get_tensor(arg2);
5659
vTensorPtr t_out = graph.get_tensor(out);
5760

5861
check_binary_op_args(*t_in1, *t_in2, *t_out);
@@ -78,7 +81,7 @@ void add_binary_op_node(
7881
graph.create_local_wg_size(out),
7982
// Inputs and Outputs
8083
{{out, vkapi::MemoryAccessType::WRITE},
81-
{{in1, in2}, vkapi::MemoryAccessType::READ}},
84+
{{arg1, arg2}, vkapi::MemoryAccessType::READ}},
8285
// Shader params buffers
8386
{t_out->sizes_ubo(),
8487
t_out->axis_map_ubo(),

0 commit comments

Comments
 (0)