Commit b0f75f9

Update base for Update on "[ET-VK] Implement prepack nodes"
## Context

This diff implements the idea described in the previous diff in this stack. During export, `et_vk.prepack` nodes will be inserted to convert constant tensors to GPU tensor objects. As a result, Vulkan operators no longer have to account for the possibility that their arguments are constant tensor data rather than actual tensor objects.

Differential Revision: [D64603666](https://our.internmc.facebook.com/intern/diff/D64603666/)

[ghstack-poisoned]
1 parent 5394756 commit b0f75f9
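
In the diffs below, every operator converges on the same call-site shape: prepack the constant argument into a real GPU tensor at the top of the function, with the trailing bool documented as /*passthrough = */. For orientation, here is a minimal sketch of that pattern, assuming only the ComputeGraph, ValueRef, prepack_standard, graph.storage_type_of, and utils::k*Packed names visible in the hunks below; the operator name and everything else in the sketch is hypothetical.

// Hypothetical operator body (not part of this commit) showing the
// standardized prepack call-site style from Linear.cpp and MatMul.cpp.
void add_example_op_node(
    ComputeGraph& graph,
    const ValueRef mat2_data, // constant tensor data (tref) captured at export
    const ValueRef out) {
  // Convert the constant into a GPU tensor object up front, so the rest of
  // the operator only ever deals with a real tensor.
  const ValueRef mat2 = prepack_standard(
      graph,
      mat2_data,
      graph.storage_type_of(out),
      utils::kHeightPacked,
      /*passthrough = */ true);
  // ... record the compute dispatch using `mat2` like any other tensor ...
  (void)mat2;
}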

3 files changed: 28 additions & 15 deletions

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 8 additions & 8 deletions

@@ -95,10 +95,10 @@ void add_addmm_naive_node(
     const Params& params,
     const ValueRef mat2_is_transposed) {
   utils::StorageType stype = graph.storage_type_of(out);
-  ValueRef self =
-      prepack_standard(graph, self_data, stype, utils::kWidthPacked, true);
-  ValueRef mat2 =
-      prepack_standard(graph, mat2_data, stype, utils::kHeightPacked, true);
+  ValueRef self = prepack_standard(
+      graph, self_data, stype, utils::kWidthPacked, /*passthrough = */ true);
+  ValueRef mat2 = prepack_standard(
+      graph, mat2_data, stype, utils::kHeightPacked, /*passthrough = */ true);
 
   std::string kernel_name =
       graph.get_bool(mat2_is_transposed) ? "linear_naive" : "addmm_naive";
@@ -149,10 +149,10 @@ void add_addmm_optimized_node(
     const Params& params,
     const ValueRef mat2_is_transposed) {
   utils::StorageType stype = graph.storage_type_of(out);
-  ValueRef self =
-      prepack_standard(graph, self_data, stype, utils::kChannelsPacked, true);
-  ValueRef mat2 =
-      prepack_standard(graph, mat2_data, stype, utils::kHeightPacked, true);
+  ValueRef self = prepack_standard(
+      graph, self_data, stype, utils::kChannelsPacked, /*passthrough=*/true);
+  ValueRef mat2 = prepack_standard(
+      graph, mat2_data, stype, utils::kHeightPacked, /*passthrough=*/true);
 
   // Ensure mat1 is width packed
   ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 15 additions & 3 deletions

@@ -63,7 +63,11 @@ void add_matmul_naive_buffer_node(
     const ValueRef out,
     const ValueRef mat2_is_transposed) {
   ValueRef mat2 = prepack_standard(
-      graph, mat2_data, graph.storage_type_of(out), utils::kHeightPacked, true);
+      graph,
+      mat2_data,
+      graph.storage_type_of(out),
+      utils::kHeightPacked,
+      /*passthrough = */ true);
 
   std::string kernel_name = "matmul_naive_buffer";
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
@@ -105,7 +109,11 @@ void add_matmul_naive_texture3d_node(
     const ValueRef out,
     const ValueRef mat2_is_transposed) {
   ValueRef mat2 = prepack_standard(
-      graph, mat2_data, graph.storage_type_of(out), utils::kHeightPacked, true);
+      graph,
+      mat2_data,
+      graph.storage_type_of(out),
+      utils::kHeightPacked,
+      /*passthrough = */ true);
 
   std::string kernel_name = graph.get_bool(mat2_is_transposed)
       ? "matmul_transposed_naive"
@@ -149,7 +157,11 @@ void add_matmul_optimized_node(
     const ValueRef out,
     const ValueRef mat2_is_transposed) {
   ValueRef mat2 = prepack_standard(
-      graph, mat2_data, graph.storage_type_of(out), utils::kHeightPacked, true);
+      graph,
+      mat2_data,
+      graph.storage_type_of(out),
+      utils::kHeightPacked,
+      /*passthrough = */ true);
 
   // Ensure mat1 is width packed
   ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 5 additions & 4 deletions

@@ -110,7 +110,7 @@ void add_tensor_to_staging_node(
       {SV(graph.packed_dim_of(in_tensor))}));
 }
 
-void add_standard_prepack_node(
+void add_prepack_standard_node(
     ComputeGraph& graph,
    const ValueRef tensor_data,
     const ValueRef tensor) {
@@ -152,7 +152,7 @@ ValueRef prepack_standard(
   }
   VK_CHECK_COND(graph.val_is_tref(tensor_data));
   ValueRef tensor = graph.add_tensor_like(tensor_data, storage_type, layout);
-  add_standard_prepack_node(graph, tensor_data, tensor);
+  add_prepack_standard_node(graph, tensor_data, tensor);
   return tensor;
 }
 
@@ -170,7 +170,7 @@ ValueRef prepack_standard_like(
       passthrough);
 }
 
-void add_direct_buffer_copy_prepack_node(
+void add_prepack_direct_buffer_copy_node(
     ComputeGraph& graph,
     const ValueRef tensor_data,
     const ValueRef tensor) {
@@ -198,9 +198,10 @@ void add_direct_buffer_copy_prepack_node(
 ValueRef prepack_direct_copy_buffer(
     ComputeGraph& graph,
     const ValueRef tensor_data) {
+  VK_CHECK_COND(graph.val_is_tref(tensor_data));
   ValueRef tensor =
       graph.add_tensor_like(tensor_data, utils::kBuffer, utils::kWidthPacked);
-  add_direct_buffer_copy_prepack_node(graph, tensor_data, tensor);
+  add_prepack_direct_buffer_copy_node(graph, tensor_data, tensor);
   return tensor;
 }
 

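As a closing reference, here is a hedged sketch of how a caller might pick between the two prepack entry points touched in Staging.cpp above. Only prepack_standard, prepack_direct_copy_buffer, ComputeGraph, ValueRef, graph.storage_type_of, and utils::kWidthPacked come from this commit; the helper itself and its parameters are illustrative assumptions.

// Hypothetical helper (not part of this commit) contrasting the two prepack
// entry points defined in Staging.cpp.
ValueRef prepack_weight_sketch(
    ComputeGraph& graph,
    const ValueRef weight_data, // constant tensor data (tref)
    const ValueRef out,         // used only to pick a storage type
    const bool copy_verbatim) {
  if (copy_verbatim) {
    // Buffer-backed, width-packed tensor filled by a direct copy of the tref;
    // this path now checks VK_CHECK_COND(graph.val_is_tref(tensor_data)).
    return prepack_direct_copy_buffer(graph, weight_data);
  }
  // Standard prepack path, with the passthrough flag named at the call site.
  return prepack_standard(
      graph,
      weight_data,
      graph.storage_type_of(out),
      utils::kWidthPacked,
      /*passthrough = */ true);
}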