Skip to content

Commit f3a1092

Browse files
committed
Update on "[ET-VK] Implement prepack nodes"
## Context

This diff implements the idea described in the previous diff in this stack. During export, `et_vk.prepack` nodes will be inserted to convert constant tensors to GPU tensor objects. This ensures that Vulkan operators do not have to account for the possibility that their arguments may be constant tensor data instead of an actual tensor object.

Differential Revision: [D64603666](https://our.internmc.facebook.com/intern/diff/D64603666/)

[ghstack-poisoned]
2 parents ff45ace + b0f75f9 commit f3a1092

File tree

6 files changed

+56
-52
lines changed

6 files changed

+56
-52
lines changed

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def is_non_weight_param_tensor(node: torch.fx.Node) -> bool:
8686
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated
8787
# memory object. This pass must be executed AFTER the memory planning pass.
8888
prepack_node.meta["spec"].mem_obj_id = -1
89-
node.replace_all_uses_with(prepack_node, lambda x: x != prepack_node)
89+
node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
9090

9191
program.graph.eliminate_dead_code()
9292
return program

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ void add_binary_op_node(
5151
const ValueRef alpha,
5252
const ValueRef out,
5353
const std::string& op_name) {
54-
VK_CHECK_COND(graph.val_is_tensor(in1));
55-
VK_CHECK_COND(graph.val_is_tensor(in2));
56-
5754
vTensorPtr t_in1 = graph.get_tensor(in1);
5855
vTensorPtr t_in2 = graph.get_tensor(in2);
5956
vTensorPtr t_out = graph.get_tensor(out);

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ void add_addmm_naive_node(
9595
const Params& params,
9696
const ValueRef mat2_is_transposed) {
9797
utils::StorageType stype = graph.storage_type_of(out);
98-
ValueRef self =
99-
prepack_standard(graph, self_data, stype, utils::kWidthPacked, true);
100-
ValueRef mat2 =
101-
prepack_standard(graph, mat2_data, stype, utils::kHeightPacked, true);
98+
ValueRef self = prepack_standard(
99+
graph, self_data, stype, utils::kWidthPacked, /*passthrough = */ true);
100+
ValueRef mat2 = prepack_standard(
101+
graph, mat2_data, stype, utils::kHeightPacked, /*passthrough = */ true);
102102

103103
std::string kernel_name =
104104
graph.get_bool(mat2_is_transposed) ? "linear_naive" : "addmm_naive";
@@ -149,10 +149,10 @@ void add_addmm_optimized_node(
149149
const Params& params,
150150
const ValueRef mat2_is_transposed) {
151151
utils::StorageType stype = graph.storage_type_of(out);
152-
ValueRef self =
153-
prepack_standard(graph, self_data, stype, utils::kChannelsPacked, true);
154-
ValueRef mat2 =
155-
prepack_standard(graph, mat2_data, stype, utils::kHeightPacked, true);
152+
ValueRef self = prepack_standard(
153+
graph, self_data, stype, utils::kChannelsPacked, /*passthrough=*/true);
154+
ValueRef mat2 = prepack_standard(
155+
graph, mat2_data, stype, utils::kHeightPacked, /*passthrough=*/true);
156156

157157
// Ensure mat1 is width packed
158158
ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ void add_matmul_naive_buffer_node(
6363
const ValueRef out,
6464
const ValueRef mat2_is_transposed) {
6565
ValueRef mat2 = prepack_standard(
66-
graph, mat2_data, graph.storage_type_of(out), utils::kHeightPacked, true);
66+
graph,
67+
mat2_data,
68+
graph.storage_type_of(out),
69+
utils::kHeightPacked,
70+
/*passthrough = */ true);
6771

6872
std::string kernel_name = "matmul_naive_buffer";
6973
add_dtype_suffix(kernel_name, graph.dtype_of(out));
@@ -105,7 +109,11 @@ void add_matmul_naive_texture3d_node(
105109
const ValueRef out,
106110
const ValueRef mat2_is_transposed) {
107111
ValueRef mat2 = prepack_standard(
108-
graph, mat2_data, graph.storage_type_of(out), utils::kHeightPacked, true);
112+
graph,
113+
mat2_data,
114+
graph.storage_type_of(out),
115+
utils::kHeightPacked,
116+
/*passthrough = */ true);
109117

110118
std::string kernel_name = graph.get_bool(mat2_is_transposed)
111119
? "matmul_transposed_naive"
@@ -149,7 +157,11 @@ void add_matmul_optimized_node(
149157
const ValueRef out,
150158
const ValueRef mat2_is_transposed) {
151159
ValueRef mat2 = prepack_standard(
152-
graph, mat2_data, graph.storage_type_of(out), utils::kHeightPacked, true);
160+
graph,
161+
mat2_data,
162+
graph.storage_type_of(out),
163+
utils::kHeightPacked,
164+
/*passthrough = */ true);
153165

154166
// Ensure mat1 is width packed
155167
ValueRef mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ void add_tensor_to_staging_node(
112112
{SV(graph.packed_dim_of(in_tensor))}));
113113
}
114114

115-
void add_standard_prepack_node(
115+
void add_prepack_standard_node(
116116
ComputeGraph& graph,
117117
const ValueRef tensor_data,
118118
const ValueRef tensor) {
@@ -154,7 +154,7 @@ ValueRef prepack_standard(
154154
}
155155
VK_CHECK_COND(graph.val_is_tref(tensor_data));
156156
ValueRef tensor = graph.add_tensor_like(tensor_data, storage_type, layout);
157-
add_standard_prepack_node(graph, tensor_data, tensor);
157+
add_prepack_standard_node(graph, tensor_data, tensor);
158158
return tensor;
159159
}
160160

@@ -172,7 +172,7 @@ ValueRef prepack_standard_like(
172172
passthrough);
173173
}
174174

175-
void add_direct_buffer_copy_prepack_node(
175+
void add_prepack_direct_buffer_copy_node(
176176
ComputeGraph& graph,
177177
const ValueRef tensor_data,
178178
const ValueRef tensor) {
@@ -200,14 +200,15 @@ void add_direct_buffer_copy_prepack_node(
200200
ValueRef prepack_direct_copy_buffer(
201201
ComputeGraph& graph,
202202
const ValueRef tensor_data) {
203+
VK_CHECK_COND(graph.val_is_tref(tensor_data));
203204
ValueRef tensor =
204205
graph.add_tensor_like(tensor_data, utils::kBuffer, utils::kWidthPacked);
205-
add_direct_buffer_copy_prepack_node(graph, tensor_data, tensor);
206+
add_prepack_direct_buffer_copy_node(graph, tensor_data, tensor);
206207
return tensor;
207208
}
208209

209210
void prepack_op(ComputeGraph& graph, const std::vector<ValueRef>& args) {
210-
return add_standard_prepack_node(graph, args[0], args[1]);
211+
return add_prepack_standard_node(graph, args[0], args[1]);
211212
}
212213

213214
REGISTER_OPERATORS {

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1520,11 +1520,18 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) {
15201520
ValueRef c = graph.add_tensor(size_big, vkapi::kFloat);
15211521
ValueRef e = graph.add_tensor(size_big, vkapi::kFloat);
15221522

1523+
ValueRef w1_packed = graph.add_tensor(size_small, vkapi::kFloat);
1524+
ValueRef w2_packed = graph.add_tensor(size_small, vkapi::kFloat);
1525+
1526+
auto prepackFn = VK_GET_OP_FN("et_vk.prepack.default");
1527+
prepackFn(graph, {w1, w1_packed});
1528+
prepackFn(graph, {w2, w2_packed});
1529+
15231530
auto addFn = VK_GET_OP_FN("aten.add.Tensor");
1524-
addFn(graph, {a.value, w1, kDummyValueRef, c});
1531+
addFn(graph, {a.value, w1_packed, kDummyValueRef, c});
15251532

15261533
auto mulFn = VK_GET_OP_FN("aten.mul.Tensor");
1527-
mulFn(graph, {c, w2, e});
1534+
mulFn(graph, {c, w2_packed, e});
15281535

15291536
IOValueRef out = {};
15301537
out.value = e;
@@ -2597,8 +2604,7 @@ void test_binary_op(
25972604
std::vector<int64_t> sizes_big,
25982605
std::vector<int64_t> sizes_small,
25992606
vkapi::ScalarType dtype,
2600-
utils::GPUMemoryLayout memory_layout,
2601-
bool prepack = true) {
2607+
utils::GPUMemoryLayout memory_layout) {
26022608
GraphConfig config;
26032609
ComputeGraph graph(config);
26042610

@@ -2609,12 +2615,7 @@ void test_binary_op(
26092615
// Build graph
26102616

26112617
IOValueRef arg1 = graph.add_input_tensor(sizes_big, dtype, memory_layout);
2612-
2613-
if (prepack) {
2614-
arg2.value = arg2_w;
2615-
} else {
2616-
arg2 = graph.add_input_tensor(sizes_small, dtype, memory_layout);
2617-
}
2618+
arg2 = graph.add_input_tensor(sizes_small, dtype, memory_layout);
26182619

26192620
IOValueRef out;
26202621
out.value = graph.add_tensor(sizes_big, dtype, memory_layout);
@@ -2635,7 +2636,7 @@ void test_binary_op(
26352636

26362637
for (int i = 1; i < 4; i++) {
26372638
float val_arg1 = i + 1.5;
2638-
float val_arg2 = prepack ? 2.5f : i - 3.5;
2639+
float val_arg2 = i - 3.5;
26392640

26402641
float val_out = val_arg1 + val_arg2;
26412642
if (op_name == "sub") {
@@ -2648,21 +2649,14 @@ void test_binary_op(
26482649
val_out = val_arg1 / val_arg2;
26492650
}
26502651

2651-
if (prepack) {
2652-
execute_graph_and_check_output(graph, {val_arg1}, {val_out});
2653-
} else {
2654-
execute_graph_and_check_output(graph, {val_arg1, val_arg2}, {val_out});
2655-
}
2652+
execute_graph_and_check_output(graph, {val_arg1, val_arg2}, {val_out});
26562653
}
26572654
}
26582655

2659-
#define CALL_TEST_FN_FORALL_CONDITIONS(_) \
2660-
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, false) \
2661-
_(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked, false) \
2662-
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, false) \
2663-
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, true) \
2664-
_(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked, true) \
2665-
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, true)
2656+
#define CALL_TEST_FN_FORALL_CONDITIONS(_) \
2657+
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked) \
2658+
_(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked) \
2659+
_(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked)
26662660

26672661
#define CALL_TEST_FN_FOR_W_PACKED(_) \
26682662
_(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, false) \
@@ -2677,15 +2671,15 @@ void test_binary_op(
26772671
_(vkapi::kFloat, utils::kBuffer, utils::kChannelsPacked, true)
26782672

26792673
TEST(VulkanComputeGraphOpsTest, add_smoke_test) {
2680-
#define RUN_TESTS(dtype, storage, layout, prepack) \
2681-
test_binary_op("add", {17, 21}, {17, 21}, dtype, layout, prepack); \
2682-
test_binary_op("add", {17, 21}, {1, 1}, dtype, layout, prepack); \
2683-
test_binary_op("sub", {11, 22}, {11, 22}, dtype, layout, prepack); \
2684-
test_binary_op("sub", {11, 22}, {11, 1}, dtype, layout, prepack); \
2685-
test_binary_op("add", {7, 17, 17}, {7, 17, 17}, dtype, layout, prepack); \
2686-
test_binary_op("add", {7, 17, 17}, {7, 1, 17}, dtype, layout, prepack); \
2687-
test_binary_op("sub", {9, 9, 7}, {9, 9, 7}, dtype, layout, prepack); \
2688-
test_binary_op("sub", {9, 9, 7}, {9, 1, 1}, dtype, layout, prepack);
2674+
#define RUN_TESTS(dtype, storage, layout) \
2675+
test_binary_op("add", {17, 21}, {17, 21}, dtype, layout); \
2676+
test_binary_op("add", {17, 21}, {1, 1}, dtype, layout); \
2677+
test_binary_op("sub", {11, 22}, {11, 22}, dtype, layout); \
2678+
test_binary_op("sub", {11, 22}, {11, 1}, dtype, layout); \
2679+
test_binary_op("add", {7, 17, 17}, {7, 17, 17}, dtype, layout); \
2680+
test_binary_op("add", {7, 17, 17}, {7, 1, 17}, dtype, layout); \
2681+
test_binary_op("sub", {9, 9, 7}, {9, 9, 7}, dtype, layout); \
2682+
test_binary_op("sub", {9, 9, 7}, {9, 1, 1}, dtype, layout);
26892683

26902684
CALL_TEST_FN_FORALL_CONDITIONS(RUN_TESTS);
26912685

0 commit comments

Comments
 (0)