Skip to content

Commit 499f8f2

Browse files
committed
Update on "[ET-VK] Implement prepack nodes"
## Context This diff implements the idea described in the previous diff in this stack. During export, `et_vk.prepack` nodes will be inserted to convert constant tensors to GPU tensor objects. This makes it so that Vulkan operators will not have to account for the possibility that their arguments can potentially be constant tensor data instead of an actual tensor object. Differential Revision: [D64603666](https://our.internmc.facebook.com/intern/diff/D64603666/) [ghstack-poisoned]
2 parents f3a1092 + 04db940 commit 499f8f2

File tree

5 files changed

+6
-8
lines changed

5 files changed

+6
-8
lines changed

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ ValueRef prepack_standard_like(
172172
passthrough);
173173
}
174174

175-
void add_prepack_direct_buffer_copy_node(
175+
void add_prepack_direct_copy_buffer_node(
176176
ComputeGraph& graph,
177177
const ValueRef tensor_data,
178178
const ValueRef tensor) {
@@ -203,7 +203,7 @@ ValueRef prepack_direct_copy_buffer(
203203
VK_CHECK_COND(graph.val_is_tref(tensor_data));
204204
ValueRef tensor =
205205
graph.add_tensor_like(tensor_data, utils::kBuffer, utils::kWidthPacked);
206-
add_prepack_direct_buffer_copy_node(graph, tensor_data, tensor);
206+
add_prepack_direct_copy_buffer_node(graph, tensor_data, tensor);
207207
return tensor;
208208
}
209209

backends/vulkan/runtime/graph/ops/impl/Staging.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ void add_tensor_to_staging_node(
3939
* tensor shader. The created `Tensor` value is then returned.
4040
*
4141
* If `passthrough` is `true`, then `v` may be a `Tensor` as well. If `v` is a
42-
* `Tensor`, then it is returned as-is. If `passthrough` if `false` (default),
42+
* `Tensor`, then it is returned as-is. If `passthrough` is `false` (default),
4343
* then an exception will be thrown.
4444
*/
4545

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,8 +2610,6 @@ void test_binary_op(
26102610

26112611
IOValueRef arg2{};
26122612

2613-
CREATE_WEIGHT_TENSOR(arg2_w, sizes_small, dtype, 2.5f);
2614-
26152613
// Build graph
26162614

26172615
IOValueRef arg1 = graph.add_input_tensor(sizes_big, dtype, memory_layout);

examples/portable/custom_ops/custom_ops_1_out.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ void check_preconditions(const Tensor& in, Tensor& out) {
2020
ET_CHECK_MSG(
2121
out.scalar_type() == ScalarType::Float,
2222
"Expected out tensor to have dtype Float, but got %hhd instead",
23-
out.scalar_type());
23+
static_cast<int8_t>(out.scalar_type()));
2424
ET_CHECK_MSG(
2525
in.scalar_type() == ScalarType::Float,
2626
"Expected in tensor to have dtype Float, but got %hhd instead",
27-
in.scalar_type());
27+
static_cast<int8_t>(in.scalar_type()));
2828
ET_CHECK_MSG(
2929
out.dim() == in.dim(),
3030
"Number of dims of out tensor is not compatible with inputs");

kernels/test/custom_kernel_example/op_relu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ my_relu_out(KernelRuntimeContext& context, const Tensor& input, Tensor& out) {
8787
InvalidArgument,
8888
out,
8989
"Unhandled dtype %hhd",
90-
input.scalar_type());
90+
static_cast<int8_t>(input.scalar_type()));
9191
}
9292
#undef RELU
9393

0 commit comments

Comments
 (0)