1010
1111#include < executorch/backends/vulkan/runtime/graph/Logging.h>
1212
13+ #include < executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
14+
1315#include < executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1416#include < executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1517#include < executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1618
1719namespace vkcompute {
1820
21+ void resize_clone_node (
22+ ComputeGraph* graph,
23+ const std::vector<ArgGroup>& args,
24+ const std::vector<ValueRef>& extra_args) {
25+ (void )extra_args;
26+ vTensorPtr out = graph->get_tensor (args[0 ].refs [0 ]);
27+ vTensorPtr in = graph->get_tensor (args[1 ].refs [0 ]);
28+ // TODO: support for when dimensionality doesn't match, i.e. clone is used to
29+ // implement squeeze.
30+ if (out->dim () == in->dim ()) {
31+ out->virtual_resize (in->sizes ());
32+ }
33+ }
34+
1935void add_clone_node (
2036 ComputeGraph& graph,
2137 const ValueRef in,
@@ -30,14 +46,84 @@ void add_clone_node(
3046 VK_KERNEL_FROM_STR (kernel_name),
3147 graph.create_global_wg_size (out),
3248 graph.create_local_wg_size (out),
33- {{out, vkapi::MemoryAccessType::WRITE},
34- {in, vkapi::MemoryAccessType::READ}},
35- {t_out->logical_limits_ubo ()}));
49+ // Inputs and Outputs
50+ {{out, vkapi::kWrite }, {in, vkapi::kRead }},
51+ // Parameter Buffers
52+ {t_out->logical_limits_ubo ()},
53+ // Specialization Constants
54+ {},
55+ // Resizing Logic
56+ resize_clone_node));
57+ }
58+
59+ void add_image_to_buffer_node (
60+ ComputeGraph& graph,
61+ const ValueRef image,
62+ const ValueRef buffer) {
63+ std::string kernel_name = " clone_image_to_buffer" ;
64+ add_dtype_suffix (kernel_name, graph.dtype_of (image));
65+ vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR (kernel_name);
66+
67+ utils::uvec3 global_wg_size = graph.create_global_wg_size (image);
68+ graph.execute_nodes ().emplace_back (new DispatchNode (
69+ graph,
70+ shader,
71+ global_wg_size,
72+ graph.create_local_wg_size (global_wg_size),
73+ // Input and Outputs
74+ {{buffer, vkapi::kWrite }, {image, vkapi::kRead }},
75+ // Parameter Buffers
76+ {graph.sizes_ubo (image), graph.strides_ubo (buffer)},
77+ // Specialization Constants
78+ {graph.hashed_layout_of (image)},
79+ // Resizing Logic
80+ resize_clone_node));
81+ }
82+
83+ void add_buffer_to_image_node (
84+ ComputeGraph& graph,
85+ const ValueRef buffer,
86+ const ValueRef image) {
87+ std::string kernel_name = " clone_buffer_to_image" ;
88+ add_dtype_suffix (kernel_name, graph.dtype_of (image));
89+ vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR (kernel_name);
90+
91+ utils::uvec3 global_wg_size = graph.create_global_wg_size (image);
92+ graph.execute_nodes ().emplace_back (new DispatchNode (
93+ graph,
94+ shader,
95+ global_wg_size,
96+ graph.create_local_wg_size (global_wg_size),
97+ // Input and Outputs
98+ {{image, vkapi::kWrite }, {buffer, vkapi::kRead }},
99+ // Parameter Buffers
100+ {graph.sizes_ubo (image), graph.strides_ubo (buffer)},
101+ // Specialization Constants
102+ {graph.hashed_layout_of (image)},
103+ // Resizing Logic
104+ resize_clone_node));
36105}
37106
38107void clone (ComputeGraph& graph, const std::vector<ValueRef>& args) {
39- // The vulkan delegate does not support changing memory format.
40- return add_clone_node (graph, args[0 ], args[2 ]);
108+ const ValueRef src = args[0 ];
109+ const ValueRef dst = args[2 ];
110+
111+ const utils::StorageType src_storage = graph.storage_type_of (src);
112+ const utils::StorageType dst_storage = graph.storage_type_of (dst);
113+ if (src_storage == utils::kTexture3D && dst_storage == utils::kTexture3D ) {
114+ if (graph.hashed_layout_of (src) == graph.hashed_layout_of (dst)) {
115+ return add_clone_node (graph, src, dst);
116+ } else {
117+ return add_view_node (graph, src, kDummyValueRef , dst);
118+ }
119+ }
120+ if (src_storage == utils::kTexture3D && dst_storage == utils::kBuffer ) {
121+ return add_image_to_buffer_node (graph, src, dst);
122+ }
123+ if (src_storage == utils::kBuffer && dst_storage == utils::kTexture3D ) {
124+ return add_buffer_to_image_node (graph, src, dst);
125+ }
126+ VK_THROW (" Buffer to buffer memory layout transition not supported yet!" );
41127}
42128
43129// Clone node is not the most efficient implementation for the aten.clone
0 commit comments