Skip to content

Commit 04db940

Browse files
committed
Update base for Update on "[ET-VK] Implement prepack nodes"
## Context This diff implements the idea described in the previous diff in this stack. During export, `et_vk.prepack` nodes will be inserted to convert constant tensors to GPU tensor objects. This makes it so that Vulkan operators will not have to account for the possibility that their arguments can potentially be constant tensor data instead of an actual tensor object. Differential Revision: [D64603666](https://our.internmc.facebook.com/intern/diff/D64603666/) [ghstack-poisoned]
2 parents b0f75f9 + ced983a commit 04db940

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

backends/vulkan/runtime/graph/ops/impl/Staging.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ ValueRef prepack_standard_like(
170170
passthrough);
171171
}
172172

173-
void add_prepack_direct_buffer_copy_node(
173+
void add_prepack_direct_copy_buffer_node(
174174
ComputeGraph& graph,
175175
const ValueRef tensor_data,
176176
const ValueRef tensor) {
@@ -201,7 +201,7 @@ ValueRef prepack_direct_copy_buffer(
201201
VK_CHECK_COND(graph.val_is_tref(tensor_data));
202202
ValueRef tensor =
203203
graph.add_tensor_like(tensor_data, utils::kBuffer, utils::kWidthPacked);
204-
add_prepack_direct_buffer_copy_node(graph, tensor_data, tensor);
204+
add_prepack_direct_copy_buffer_node(graph, tensor_data, tensor);
205205
return tensor;
206206
}
207207

backends/vulkan/runtime/graph/ops/impl/Staging.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ void add_tensor_to_staging_node(
3939
* tensor shader. The created `Tensor` value is then returned.
4040
*
4141
* If `passthrough` is `true`, then `v` may be a `Tensor` as well. If `v` is a
42-
* `Tensor`, then it is returned as-is. If `passthrough` if `false` (default),
42+
* `Tensor`, then it is returned as-is. If `passthrough` is `false` (default),
4343
* then an exception will be thrown.
4444
*/
4545

examples/portable/custom_ops/custom_ops_1_out.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ void check_preconditions(const Tensor& in, Tensor& out) {
2020
ET_CHECK_MSG(
2121
out.scalar_type() == ScalarType::Float,
2222
"Expected out tensor to have dtype Float, but got %hhd instead",
23-
out.scalar_type());
23+
static_cast<int8_t>(out.scalar_type()));
2424
ET_CHECK_MSG(
2525
in.scalar_type() == ScalarType::Float,
2626
"Expected in tensor to have dtype Float, but got %hhd instead",
27-
in.scalar_type());
27+
static_cast<int8_t>(in.scalar_type()));
2828
ET_CHECK_MSG(
2929
out.dim() == in.dim(),
3030
"Number of dims of out tensor is not compatible with inputs");

kernels/test/custom_kernel_example/op_relu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ my_relu_out(KernelRuntimeContext& context, const Tensor& input, Tensor& out) {
8787
InvalidArgument,
8888
out,
8989
"Unhandled dtype %hhd",
90-
input.scalar_type());
90+
static_cast<int8_t>(input.scalar_type()));
9191
}
9292
#undef RELU
9393

0 commit comments

Comments
 (0)