Skip to content

Commit 2ee526e

Browse files
authored
[ET-VK] Allow clone op to transfer between memory layouts and storage types
Differential Revision: D65277710 Pull Request resolved: #6596
1 parent 335056c commit 2ee526e

14 files changed

+245
-42
lines changed

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,22 @@ class ComputeGraph final {
612612
return {t, staging};
613613
}
614614

615+
/*
616+
* Add an input tensor with the specified properties along with its staging
617+
* buffer.
618+
*/
619+
inline IOValueRef add_input_tensor(
620+
const std::vector<int64_t>& sizes,
621+
const vkapi::ScalarType dtype,
622+
const utils::StorageType storage_type,
623+
const utils::GPUMemoryLayout memory_layout,
624+
const int64_t shared_object_idx = -1) {
625+
ValueRef t = add_tensor(
626+
sizes, dtype, storage_type, memory_layout, shared_object_idx);
627+
ValueRef staging = set_input_tensor(t);
628+
return {t, staging};
629+
}
630+
615631
SharedObject& get_shared_object(const int64_t idx);
616632

617633
//

backends/vulkan/runtime/graph/ops/glsl/bitw8_image_to_nchw_nobitw8buffer.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ bitw8_image_to_nchw_nobitw8buffer:
99
STORAGE: texture3d
1010
DTYPE: int8
1111
generate_variant_forall:
12-
DTYPE:
13-
- VALUE: int8
14-
- VALUE: uint8
1512
STORAGE:
1613
- VALUE: texture2d
1714
- VALUE: texture3d
15+
DTYPE:
16+
- VALUE: int8
17+
- VALUE: uint8
1818
shader_variants:
1919
- NAME: bitw8_image_to_nchw_nobitw8buffer

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ ${define_required_extensions(DTYPE)}
1919

2020
layout(std430) buffer;
2121

22-
${layout_declare_buffer(B, "w", "nchw_out", DTYPE)}
22+
${layout_declare_buffer(B, "w", "buf_out", DTYPE)}
2323
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
2424
${layout_declare_ubo(B, "ivec4", "sizes")}
25+
$if not TO_STAGING:
26+
${layout_declare_ubo(B, "ivec4", "buf_strides")}
2527

2628
#include "indexing_utils.h"
2729

@@ -31,23 +33,23 @@ ${layout_declare_spec_const(C, "int", "t_layout", "DEFAULT_LAYOUT")}
3133
const lowp ivec4 axis_map = unhash_axis_map(t_layout);
3234
const lowp int packed_dim = unhash_packed_dim(t_layout);
3335

34-
void write_out_texel(VEC4_T texel, ivec4 tensor_idx) {
35-
const ivec4 buf_indices = tidx_to_nchwi(
36-
tensor_idx,
37-
sizes,
38-
packed_dim);
36+
void write_out_texel(VEC4_T texel, ivec4 tidx) {
37+
$if TO_STAGING:
38+
const ivec4 buf_indices = tidx_to_nchwi(tidx, sizes, packed_dim);
39+
$else:
40+
const ivec4 buf_indices = tidx_to_4bufi(tidx, buf_strides, packed_dim);
3941

40-
if (tensor_idx[packed_dim] < sizes[packed_dim]) {
41-
nchw_out[buf_indices.x] = BUF_T(texel.x);
42+
if (tidx[packed_dim] < sizes[packed_dim]) {
43+
buf_out[buf_indices.x] = BUF_T(texel.x);
4244
}
43-
if (tensor_idx[packed_dim] + 1 < sizes[packed_dim]) {
44-
nchw_out[buf_indices.y] = BUF_T(texel.y);
45+
if (tidx[packed_dim] + 1 < sizes[packed_dim]) {
46+
buf_out[buf_indices.y] = BUF_T(texel.y);
4547
}
46-
if (tensor_idx[packed_dim] + 2 < sizes[packed_dim]) {
47-
nchw_out[buf_indices.z] = BUF_T(texel.z);
48+
if (tidx[packed_dim] + 2 < sizes[packed_dim]) {
49+
buf_out[buf_indices.z] = BUF_T(texel.z);
4850
}
49-
if (tensor_idx[packed_dim] + 3 < sizes[packed_dim]) {
50-
nchw_out[buf_indices.w] = BUF_T(texel.w);
51+
if (tidx[packed_dim] + 3 < sizes[packed_dim]) {
52+
buf_out[buf_indices.w] = BUF_T(texel.w);
5153
}
5254
}
5355

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@ image_to_nchw:
88
parameter_names_with_default_values:
99
DTYPE: float
1010
STORAGE: texture3d
11+
TO_STAGING: True
1112
generate_variant_forall:
1213
DTYPE:
1314
- VALUE: half
1415
- VALUE: float
1516
- VALUE: int
1617
- VALUE: int8
17-
STORAGE:
18-
- VALUE: texture3d
19-
- VALUE: texture2d
2018
shader_variants:
21-
- NAME: image_to_nchw
19+
- NAME: image_to_nchw_texture3d
20+
- NAME: image_to_nchw_texture2d
21+
STORAGE: texture2d
22+
- NAME: clone_image_to_buffer
23+
TO_STAGING: False

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,21 @@ ivec4 tidx_to_nchwi(const ivec4 tidx, const ivec4 sizes, const int packed_dim) {
8888
return base_i + ivec4(0, 1, 2, 3) * strides[packed_dim];
8989
}
9090

91+
/*
92+
* Get the buffer indices that contain the data of the texel that corresponds to
93+
* the provided tensor index. Since each texel has 4 elements, 4 buffer
94+
* indices will be retrieved.
95+
*/
96+
ivec4 tidx_to_4bufi(
97+
const ivec4 tidx,
98+
const ivec4 strides,
99+
const int packed_dim) {
100+
int base_i = tidx.x * strides.x + tidx.y * strides.y + tidx.z * strides.z +
101+
tidx.w * strides.w;
102+
103+
return base_i + ivec4(0, 1, 2, 3) * strides[packed_dim];
104+
}
105+
91106
ivec4 nchwi_to_tidx(const int nchwi, const ivec4 sizes) {
92107
return ivec4(
93108
nchwi % sizes.x,

backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ nchw_to_bitw8_image_nobitw8buffer:
99
STORAGE: texture3d
1010
DTYPE: int8
1111
generate_variant_forall:
12-
DTYPE:
13-
- VALUE: int8
14-
- VALUE: uint8
1512
STORAGE:
1613
- VALUE: texture2d
1714
- VALUE: texture3d
15+
DTYPE:
16+
- VALUE: int8
17+
- VALUE: uint8
1818
shader_variants:
1919
- NAME: nchw_to_bitw8_image_nobitw8buffer

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ layout(std430) buffer;
2222
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
2323
${layout_declare_buffer(B, "r", "buf_in", DTYPE)}
2424
${layout_declare_ubo(B, "ivec4", "sizes")}
25+
$if not FROM_STAGING:
26+
${layout_declare_ubo(B, "ivec4", "buf_strides")}
2527

2628
#include "indexing_utils.h"
2729

@@ -32,10 +34,10 @@ const lowp ivec4 axis_map = unhash_axis_map(t_layout);
3234
const lowp int packed_dim = unhash_packed_dim(t_layout);
3335

3436
VEC4_T read_texel(ivec4 tidx) {
35-
const ivec4 buf_indices = tidx_to_nchwi(
36-
tidx,
37-
sizes,
38-
packed_dim);
37+
$if FROM_STAGING:
38+
const ivec4 buf_indices = tidx_to_nchwi(tidx, sizes, packed_dim);
39+
$else:
40+
const ivec4 buf_indices = tidx_to_4bufi(tidx, buf_strides, packed_dim);
3941

4042
VEC4_T texel = VEC4_T(0);
4143
if (tidx[packed_dim] < sizes[packed_dim]) {

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@ nchw_to_image:
88
parameter_names_with_default_values:
99
STORAGE: texture3d
1010
DTYPE: float
11+
FROM_STAGING: True
1112
generate_variant_forall:
1213
DTYPE:
1314
- VALUE: half
1415
- VALUE: float
1516
- VALUE: int
1617
- VALUE: int8
17-
STORAGE:
18-
- VALUE: texture3d
19-
- VALUE: texture2d
2018
shader_variants:
21-
- NAME: nchw_to_image
19+
- NAME: nchw_to_image_texture3d
20+
- NAME: nchw_to_image_texture2d
21+
STORAGE: texture2d
22+
- NAME: clone_buffer_to_image
23+
FROM_STAGING: False

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,28 @@
1010

1111
#include <executorch/backends/vulkan/runtime/graph/Logging.h>
1212

13+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
14+
1315
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1416
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1517
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1618

1719
namespace vkcompute {
1820

21+
void resize_clone_node(
22+
ComputeGraph* graph,
23+
const std::vector<ArgGroup>& args,
24+
const std::vector<ValueRef>& extra_args) {
25+
(void)extra_args;
26+
vTensorPtr out = graph->get_tensor(args[0].refs[0]);
27+
vTensorPtr in = graph->get_tensor(args[1].refs[0]);
28+
// TODO: support for when dimensionality doesn't match, i.e. clone is used to
29+
// implement squeeze.
30+
if (out->dim() == in->dim()) {
31+
out->virtual_resize(in->sizes());
32+
}
33+
}
34+
1935
void add_clone_node(
2036
ComputeGraph& graph,
2137
const ValueRef in,
@@ -30,14 +46,84 @@ void add_clone_node(
3046
VK_KERNEL_FROM_STR(kernel_name),
3147
graph.create_global_wg_size(out),
3248
graph.create_local_wg_size(out),
33-
{{out, vkapi::MemoryAccessType::WRITE},
34-
{in, vkapi::MemoryAccessType::READ}},
35-
{t_out->logical_limits_ubo()}));
49+
// Inputs and Outputs
50+
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
51+
// Parameter Buffers
52+
{t_out->logical_limits_ubo()},
53+
// Specialization Constants
54+
{},
55+
// Resizing Logic
56+
resize_clone_node));
57+
}
58+
59+
void add_image_to_buffer_node(
60+
ComputeGraph& graph,
61+
const ValueRef image,
62+
const ValueRef buffer) {
63+
std::string kernel_name = "clone_image_to_buffer";
64+
add_dtype_suffix(kernel_name, graph.dtype_of(image));
65+
vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
66+
67+
utils::uvec3 global_wg_size = graph.create_global_wg_size(image);
68+
graph.execute_nodes().emplace_back(new DispatchNode(
69+
graph,
70+
shader,
71+
global_wg_size,
72+
graph.create_local_wg_size(global_wg_size),
73+
// Inputs and Outputs
74+
{{buffer, vkapi::kWrite}, {image, vkapi::kRead}},
75+
// Parameter Buffers
76+
{graph.sizes_ubo(image), graph.strides_ubo(buffer)},
77+
// Specialization Constants
78+
{graph.hashed_layout_of(image)},
79+
// Resizing Logic
80+
resize_clone_node));
81+
}
82+
83+
void add_buffer_to_image_node(
84+
ComputeGraph& graph,
85+
const ValueRef buffer,
86+
const ValueRef image) {
87+
std::string kernel_name = "clone_buffer_to_image";
88+
add_dtype_suffix(kernel_name, graph.dtype_of(image));
89+
vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
90+
91+
utils::uvec3 global_wg_size = graph.create_global_wg_size(image);
92+
graph.execute_nodes().emplace_back(new DispatchNode(
93+
graph,
94+
shader,
95+
global_wg_size,
96+
graph.create_local_wg_size(global_wg_size),
97+
// Inputs and Outputs
98+
{{image, vkapi::kWrite}, {buffer, vkapi::kRead}},
99+
// Parameter Buffers
100+
{graph.sizes_ubo(image), graph.strides_ubo(buffer)},
101+
// Specialization Constants
102+
{graph.hashed_layout_of(image)},
103+
// Resizing Logic
104+
resize_clone_node));
36105
}
37106

38107
void clone(ComputeGraph& graph, const std::vector<ValueRef>& args) {
39-
// The vulkan delegate does not support changing memory format.
40-
return add_clone_node(graph, args[0], args[2]);
108+
const ValueRef src = args[0];
109+
const ValueRef dst = args[2];
110+
111+
const utils::StorageType src_storage = graph.storage_type_of(src);
112+
const utils::StorageType dst_storage = graph.storage_type_of(dst);
113+
if (src_storage == utils::kTexture3D && dst_storage == utils::kTexture3D) {
114+
if (graph.hashed_layout_of(src) == graph.hashed_layout_of(dst)) {
115+
return add_clone_node(graph, src, dst);
116+
} else {
117+
return add_view_node(graph, src, kDummyValueRef, dst);
118+
}
119+
}
120+
if (src_storage == utils::kTexture3D && dst_storage == utils::kBuffer) {
121+
return add_image_to_buffer_node(graph, src, dst);
122+
}
123+
if (src_storage == utils::kBuffer && dst_storage == utils::kTexture3D) {
124+
return add_buffer_to_image_node(graph, src, dst);
125+
}
126+
VK_THROW("Buffer to buffer memory layout transition not supported yet!");
41127
}
42128

43129
// Clone node is not the most efficient implementation for the aten.clone

backends/vulkan/runtime/graph/ops/impl/View.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
1010

11+
#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
12+
1113
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1214
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1315
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

0 commit comments

Comments
 (0)