
Commit f0af466

pytorchbot and SS-JIA authored
[ET-VK] Allow clone op to transfer between memory layouts and storage types (pytorch#6607)
Pull Request resolved: pytorch#6596

## Changes

As title. Extend the functionality of the `aten.clone` operator to allow transitioning the storage type and memory layout from the input tensor to the output tensor.

## Context

This functionality will be used to transition input tensors to the optimal storage type and memory layout before an op executes. The transition nodes will be added by a memory metadata tagging pass that will be introduced in a subsequent diff.

ghstack-source-id: 251229412
@exported-using-ghexport

Differential Revision: [D65277710](https://our.internmc.facebook.com/intern/diff/D65277710/)

Co-authored-by: Stephen Jia <[email protected]>
1 parent 27eac48 commit f0af466

14 files changed: +245 -42 lines
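At a high level, the updated `clone()` entry point in `Clone.cpp` (full diff below) now dispatches on the storage types of the source and destination tensors. A condensed summary for orientation only, using the identifiers from that file:

```cpp
// Orientation summary of the dispatch added in Clone.cpp (see the diff below).
//   texture -> texture, same memory layout      : add_clone_node            (plain copy)
//   texture -> texture, different memory layout : add_view_node             (re-packs the data)
//   texture -> buffer                           : add_image_to_buffer_node  (new in this commit)
//   buffer  -> texture                          : add_buffer_to_image_node  (new in this commit)
//   buffer  -> buffer                           : VK_THROW("... not supported yet!")
```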

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 16 additions & 0 deletions
@@ -612,6 +612,22 @@ class ComputeGraph final {
     return {t, staging};
   }
 
+  /*
+   * Add an input tensor with the specified properties along with its staging
+   * buffer.
+   */
+  inline IOValueRef add_input_tensor(
+      const std::vector<int64_t>& sizes,
+      const vkapi::ScalarType dtype,
+      const utils::StorageType storage_type,
+      const utils::GPUMemoryLayout memory_layout,
+      const int64_t shared_object_idx = -1) {
+    ValueRef t = add_tensor(
+        sizes, dtype, storage_type, memory_layout, shared_object_idx);
+    ValueRef staging = set_input_tensor(t);
+    return {t, staging};
+  }
+
   SharedObject& get_shared_object(const int64_t idx);
 
   //
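For illustration, a minimal usage sketch of the new overload. It assumes an existing `ComputeGraph` named `graph`; the dtype and layout constants (`vkapi::kFloat`, `utils::kChannelsPacked`) and the `value`/`staging` member names are illustrative assumptions rather than details taken from this diff:

```cpp
// Hedged example, not part of this commit: request a float, texture-backed
// input tensor with a channels-packed layout, plus its staging buffer.
std::vector<int64_t> sizes = {1, 4, 8, 8};
IOValueRef io = graph.add_input_tensor(
    sizes,
    vkapi::kFloat,            // dtype
    utils::kTexture3D,        // storage type
    utils::kChannelsPacked);  // memory layout; shared_object_idx defaults to -1
// io.value refers to the tensor; io.staging refers to its staging buffer.
```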

backends/vulkan/runtime/graph/ops/glsl/bitw8_image_to_nchw_nobitw8buffer.yaml

Lines changed: 3 additions & 3 deletions
@@ -9,11 +9,11 @@ bitw8_image_to_nchw_nobitw8buffer:
   STORAGE: texture3d
   DTYPE: int8
 generate_variant_forall:
-  DTYPE:
-    - VALUE: int8
-    - VALUE: uint8
   STORAGE:
     - VALUE: texture2d
     - VALUE: texture3d
+  DTYPE:
+    - VALUE: int8
+    - VALUE: uint8
 shader_variants:
   - NAME: bitw8_image_to_nchw_nobitw8buffer

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl

Lines changed: 16 additions & 14 deletions
@@ -19,9 +19,11 @@ ${define_required_extensions(DTYPE)}
 
 layout(std430) buffer;
 
-${layout_declare_buffer(B, "w", "nchw_out", DTYPE)}
+${layout_declare_buffer(B, "w", "buf_out", DTYPE)}
 ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
 ${layout_declare_ubo(B, "ivec4", "sizes")}
+$if not TO_STAGING:
+  ${layout_declare_ubo(B, "ivec4", "buf_strides")}
 
 #include "indexing_utils.h"
 
@@ -31,23 +33,23 @@ ${layout_declare_spec_const(C, "int", "t_layout", "DEFAULT_LAYOUT")}
 const lowp ivec4 axis_map = unhash_axis_map(t_layout);
 const lowp int packed_dim = unhash_packed_dim(t_layout);
 
-void write_out_texel(VEC4_T texel, ivec4 tensor_idx) {
-  const ivec4 buf_indices = tidx_to_nchwi(
-      tensor_idx,
-      sizes,
-      packed_dim);
+void write_out_texel(VEC4_T texel, ivec4 tidx) {
+  $if TO_STAGING:
+    const ivec4 buf_indices = tidx_to_nchwi(tidx, sizes, packed_dim);
+  $else:
+    const ivec4 buf_indices = tidx_to_4bufi(tidx, buf_strides, packed_dim);
 
-  if (tensor_idx[packed_dim] < sizes[packed_dim]) {
-    nchw_out[buf_indices.x] = BUF_T(texel.x);
+  if (tidx[packed_dim] < sizes[packed_dim]) {
+    buf_out[buf_indices.x] = BUF_T(texel.x);
   }
-  if (tensor_idx[packed_dim] + 1 < sizes[packed_dim]) {
-    nchw_out[buf_indices.y] = BUF_T(texel.y);
+  if (tidx[packed_dim] + 1 < sizes[packed_dim]) {
+    buf_out[buf_indices.y] = BUF_T(texel.y);
   }
-  if (tensor_idx[packed_dim] + 2 < sizes[packed_dim]) {
-    nchw_out[buf_indices.z] = BUF_T(texel.z);
+  if (tidx[packed_dim] + 2 < sizes[packed_dim]) {
+    buf_out[buf_indices.z] = BUF_T(texel.z);
   }
-  if (tensor_idx[packed_dim] + 3 < sizes[packed_dim]) {
-    nchw_out[buf_indices.w] = BUF_T(texel.w);
+  if (tidx[packed_dim] + 3 < sizes[packed_dim]) {
+    buf_out[buf_indices.w] = BUF_T(texel.w);
   }
 }
 

backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.yaml

Lines changed: 6 additions & 4 deletions
@@ -8,14 +8,16 @@ image_to_nchw:
 parameter_names_with_default_values:
   DTYPE: float
   STORAGE: texture3d
+  TO_STAGING: True
 generate_variant_forall:
   DTYPE:
     - VALUE: half
     - VALUE: float
    - VALUE: int
     - VALUE: int8
-  STORAGE:
-    - VALUE: texture3d
-    - VALUE: texture2d
 shader_variants:
-  - NAME: image_to_nchw
+  - NAME: image_to_nchw_texture3d
+  - NAME: image_to_nchw_texture2d
+    STORAGE: texture2d
+  - NAME: clone_image_to_buffer
+    TO_STAGING: False

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 15 additions & 0 deletions
@@ -88,6 +88,21 @@ ivec4 tidx_to_nchwi(const ivec4 tidx, const ivec4 sizes, const int packed_dim) {
   return base_i + ivec4(0, 1, 2, 3) * strides[packed_dim];
 }
 
+/*
+ * Get the buffer indices that contain the data of the texel that corresponds
+ * to the provided tensor index. Since a texel has 4 elements, 4 buffer
+ * indices will be retrieved.
+ */
+ivec4 tidx_to_4bufi(
+    const ivec4 tidx,
+    const ivec4 strides,
+    const int packed_dim) {
+  int base_i = tidx.x * strides.x + tidx.y * strides.y + tidx.z * strides.z +
+      tidx.w * strides.w;
+
+  return base_i + ivec4(0, 1, 2, 3) * strides[packed_dim];
+}
+
 ivec4 nchwi_to_tidx(const int nchwi, const ivec4 sizes) {
   return ivec4(
       nchwi % sizes.x,
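To make the index arithmetic of `tidx_to_4bufi` concrete, here is a worked example under assumed contiguous, width-packed buffer strides; the values are chosen for illustration and are not taken from the commit:

```cpp
// tidx    = (x=4, y=1, z=1, w=0), packed_dim = 0 (width)
// strides = (1, 8, 32, 64)   // contiguous WHCN strides for sizes (8, 4, 2, 1)
// base_i  = 4*1 + 1*8 + 1*32 + 0*64 = 44
// result  = base_i + ivec4(0, 1, 2, 3) * strides[packed_dim]
//         = (44, 45, 46, 47)  // four consecutive buffer elements for this texel
```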

backends/vulkan/runtime/graph/ops/glsl/nchw_to_bitw8_image_nobitw8buffer.yaml

Lines changed: 3 additions & 3 deletions
@@ -9,11 +9,11 @@ nchw_to_bitw8_image_nobitw8buffer:
   STORAGE: texture3d
   DTYPE: int8
 generate_variant_forall:
-  DTYPE:
-    - VALUE: int8
-    - VALUE: uint8
   STORAGE:
     - VALUE: texture2d
     - VALUE: texture3d
+  DTYPE:
+    - VALUE: int8
+    - VALUE: uint8
 shader_variants:
   - NAME: nchw_to_bitw8_image_nobitw8buffer

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl

Lines changed: 6 additions & 4 deletions
@@ -22,6 +22,8 @@ layout(std430) buffer;
 ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
 ${layout_declare_buffer(B, "r", "buf_in", DTYPE)}
 ${layout_declare_ubo(B, "ivec4", "sizes")}
+$if not FROM_STAGING:
+  ${layout_declare_ubo(B, "ivec4", "buf_strides")}
 
 #include "indexing_utils.h"
 
@@ -32,10 +34,10 @@ const lowp ivec4 axis_map = unhash_axis_map(t_layout);
 const lowp int packed_dim = unhash_packed_dim(t_layout);
 
 VEC4_T read_texel(ivec4 tidx) {
-  const ivec4 buf_indices = tidx_to_nchwi(
-      tidx,
-      sizes,
-      packed_dim);
+  $if FROM_STAGING:
+    const ivec4 buf_indices = tidx_to_nchwi(tidx, sizes, packed_dim);
+  $else:
+    const ivec4 buf_indices = tidx_to_4bufi(tidx, buf_strides, packed_dim);
 
   VEC4_T texel = VEC4_T(0);
   if (tidx[packed_dim] < sizes[packed_dim]) {

backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.yaml

Lines changed: 6 additions & 4 deletions
@@ -8,14 +8,16 @@ nchw_to_image:
 parameter_names_with_default_values:
   STORAGE: texture3d
   DTYPE: float
+  FROM_STAGING: True
 generate_variant_forall:
   DTYPE:
     - VALUE: half
     - VALUE: float
     - VALUE: int
     - VALUE: int8
-  STORAGE:
-    - VALUE: texture3d
-    - VALUE: texture2d
 shader_variants:
-  - NAME: nchw_to_image
+  - NAME: nchw_to_image_texture3d
+  - NAME: nchw_to_image_texture2d
+    STORAGE: texture2d
+  - NAME: clone_buffer_to_image
+    FROM_STAGING: False

backends/vulkan/runtime/graph/ops/impl/Clone.cpp

Lines changed: 91 additions & 5 deletions
@@ -10,12 +10,28 @@
 
 #include <executorch/backends/vulkan/runtime/graph/Logging.h>
 
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
+
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
 
 namespace vkcompute {
 
+void resize_clone_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  (void)extra_args;
+  vTensorPtr out = graph->get_tensor(args[0].refs[0]);
+  vTensorPtr in = graph->get_tensor(args[1].refs[0]);
+  // TODO: support for when dimensionality doesn't match, i.e. clone is used to
+  // implement squeeze.
+  if (out->dim() == in->dim()) {
+    out->virtual_resize(in->sizes());
+  }
+}
+
 void add_clone_node(
     ComputeGraph& graph,
     const ValueRef in,
@@ -30,14 +46,84 @@ void add_clone_node(
       VK_KERNEL_FROM_STR(kernel_name),
       graph.create_global_wg_size(out),
       graph.create_local_wg_size(out),
-      {{out, vkapi::MemoryAccessType::WRITE},
-       {in, vkapi::MemoryAccessType::READ}},
-      {t_out->logical_limits_ubo()}));
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Parameter Buffers
+      {t_out->logical_limits_ubo()},
+      // Specialization Constants
+      {},
+      // Resizing Logic
+      resize_clone_node));
+}
+
+void add_image_to_buffer_node(
+    ComputeGraph& graph,
+    const ValueRef image,
+    const ValueRef buffer) {
+  std::string kernel_name = "clone_image_to_buffer";
+  add_dtype_suffix(kernel_name, graph.dtype_of(image));
+  vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
+
+  utils::uvec3 global_wg_size = graph.create_global_wg_size(image);
+  graph.execute_nodes().emplace_back(new DispatchNode(
+      graph,
+      shader,
+      global_wg_size,
+      graph.create_local_wg_size(global_wg_size),
+      // Input and Outputs
+      {{buffer, vkapi::kWrite}, {image, vkapi::kRead}},
+      // Parameter Buffers
+      {graph.sizes_ubo(image), graph.strides_ubo(buffer)},
+      // Specialization Constants
+      {graph.hashed_layout_of(image)},
+      // Resizing Logic
+      resize_clone_node));
+}
+
+void add_buffer_to_image_node(
+    ComputeGraph& graph,
+    const ValueRef buffer,
+    const ValueRef image) {
+  std::string kernel_name = "clone_buffer_to_image";
+  add_dtype_suffix(kernel_name, graph.dtype_of(image));
+  vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
+
+  utils::uvec3 global_wg_size = graph.create_global_wg_size(image);
+  graph.execute_nodes().emplace_back(new DispatchNode(
+      graph,
+      shader,
+      global_wg_size,
+      graph.create_local_wg_size(global_wg_size),
+      // Input and Outputs
+      {{image, vkapi::kWrite}, {buffer, vkapi::kRead}},
+      // Parameter Buffers
+      {graph.sizes_ubo(image), graph.strides_ubo(buffer)},
+      // Specialization Constants
+      {graph.hashed_layout_of(image)},
+      // Resizing Logic
+      resize_clone_node));
 }
 
 void clone(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  // The vulkan delegate does not support changing memory format.
-  return add_clone_node(graph, args[0], args[2]);
+  const ValueRef src = args[0];
+  const ValueRef dst = args[2];
+
+  const utils::StorageType src_storage = graph.storage_type_of(src);
+  const utils::StorageType dst_storage = graph.storage_type_of(dst);
+  if (src_storage == utils::kTexture3D && dst_storage == utils::kTexture3D) {
+    if (graph.hashed_layout_of(src) == graph.hashed_layout_of(dst)) {
+      return add_clone_node(graph, src, dst);
+    } else {
+      return add_view_node(graph, src, kDummyValueRef, dst);
+    }
+  }
+  if (src_storage == utils::kTexture3D && dst_storage == utils::kBuffer) {
+    return add_image_to_buffer_node(graph, src, dst);
+  }
+  if (src_storage == utils::kBuffer && dst_storage == utils::kTexture3D) {
+    return add_buffer_to_image_node(graph, src, dst);
+  }
+  VK_THROW("Buffer to buffer memory layout transition not supported yet!");
 }
 
 // Clone node is not the most efficient implementation for the aten.clone
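To make the new transition paths concrete, here is a hedged sketch of requesting a texture-to-buffer transition at graph-build time. It assumes a live `ComputeGraph` named `graph`, calls the helper added above directly (a real caller would normally reach it through the registered `aten.clone` op), and the dtype and layout constants are illustrative:

```cpp
// Illustrative only: copy a texture-backed tensor into a buffer-backed one.
std::vector<int64_t> sizes = {1, 4, 8, 8};

// Source tensor backed by a 3D texture (e.g., produced by a preceding op).
ValueRef image = graph.add_tensor(
    sizes, vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked);

// Destination tensor backed by a buffer with a width-packed layout.
ValueRef buffer = graph.add_tensor(
    sizes, vkapi::kFloat, utils::kBuffer, utils::kWidthPacked);

// Records a clone_image_to_buffer dispatch that copies image -> buffer,
// binding the image's sizes UBO and the buffer's strides UBO as in the
// diff above.
add_image_to_buffer_node(graph, image, buffer);
```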

backends/vulkan/runtime/graph/ops/impl/View.cpp

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,8 @@
 
 #include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
 
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
+
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
