1010
1111#include < executorch/backends/vulkan/runtime/graph/Logging.h>
1212
13+ #include < executorch/backends/vulkan/runtime/graph/ops/impl/View.h>
14+
1315#include < executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
1416#include < executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
1517#include < executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
1618
1719namespace vkcompute {
1820
21+ void resize_clone_node (
22+ ComputeGraph* graph,
23+ const std::vector<ArgGroup>& args,
24+ const std::vector<ValueRef>& extra_args) {
25+ (void )extra_args;
26+ vTensorPtr out = graph->get_tensor (args[0 ].refs [0 ]);
27+ vTensorPtr in = graph->get_tensor (args[1 ].refs [0 ]);
28+ out->virtual_resize (in->sizes ());
29+ }
30+
1931void add_clone_node (
2032 ComputeGraph& graph,
2133 const ValueRef in,
@@ -30,14 +42,86 @@ void add_clone_node(
3042 VK_KERNEL_FROM_STR (kernel_name),
3143 graph.create_global_wg_size (out),
3244 graph.create_local_wg_size (out),
33- {{out, vkapi::MemoryAccessType::WRITE},
34- {in, vkapi::MemoryAccessType::READ}},
35- {t_out->logical_limits_ubo ()}));
45+ // Inputs and Outputs
46+ {{out, vkapi::kWrite }, {in, vkapi::kRead }},
47+ // Parameter Buffers
48+ {t_out->logical_limits_ubo ()},
49+ // Specialization Constants
50+ {},
51+ // Resizing Logic
52+ resize_clone_node));
53+ }
54+
55+ void add_image_to_buffer_node (
56+ ComputeGraph& graph,
57+ const ValueRef image,
58+ const ValueRef buffer) {
59+ std::string kernel_name = " image_to_buffer" ;
60+ add_dtype_suffix (kernel_name, graph.dtype_of (image));
61+ vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR (kernel_name);
62+
63+ utils::uvec3 global_wg_size = graph.create_global_wg_size (image);
64+ graph.execute_nodes ().emplace_back (new DispatchNode (
65+ graph,
66+ shader,
67+ global_wg_size,
68+ graph.create_local_wg_size (global_wg_size),
69+ // Input and Outputs
70+ {{buffer, vkapi::kWrite }, {image, vkapi::kRead }},
71+ // Parameter Buffers
72+ {graph.sizes_ubo (image), graph.strides_ubo (buffer)},
73+ // Specialization Constants
74+ {graph.hashed_layout_of (image)},
75+ // Resizing Logic
76+ resize_clone_node));
77+ }
78+
79+ void add_buffer_to_image_node (
80+ ComputeGraph& graph,
81+ const ValueRef buffer,
82+ const ValueRef image) {
83+ std::string kernel_name = " buffer_to_image" ;
84+ add_dtype_suffix (kernel_name, graph.dtype_of (image));
85+ vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR (kernel_name);
86+
87+ utils::uvec3 global_wg_size = graph.create_global_wg_size (image);
88+ graph.execute_nodes ().emplace_back (new DispatchNode (
89+ graph,
90+ shader,
91+ global_wg_size,
92+ graph.create_local_wg_size (global_wg_size),
93+ // Input and Outputs
94+ {{image, vkapi::kWrite }, {buffer, vkapi::kRead }},
95+ // Parameter Buffers
96+ {graph.sizes_ubo (image), graph.strides_ubo (buffer)},
97+ // Specialization Constants
98+ {graph.hashed_layout_of (image)},
99+ // Resizing Logic
100+ resize_clone_node));
36101}
37102
38103void clone (ComputeGraph& graph, const std::vector<ValueRef>& args) {
39- // The vulkan delegate does not support changing memory format.
40- return add_clone_node (graph, args[0 ], args[2 ]);
104+ const ValueRef src = args[0 ];
105+ const ValueRef dst = args[2 ];
106+
107+ const utils::StorageType src_storage = graph.storage_type_of (src);
108+ const utils::StorageType dst_storage = graph.storage_type_of (dst);
109+ if (src_storage == dst_storage) {
110+ if (graph.hashed_layout_of (src) == graph.hashed_layout_of (dst)) {
111+ return add_clone_node (graph, src, dst);
112+ } else if (src_storage == utils::kTexture3D ) {
113+ return add_view_node (graph, src, kDummyValueRef , dst);
114+ }
115+ // TODO: Implement memory layout transition for buffer backed tensors
116+ VK_THROW (" Cannot transition memory layout for buffer backed tensors yet" );
117+ }
118+ if (src_storage == utils::kTexture3D && dst_storage == utils::kBuffer ) {
119+ return add_image_to_buffer_node (graph, src, dst);
120+ }
121+ if (src_storage == utils::kBuffer && dst_storage == utils::kTexture3D ) {
122+ return add_buffer_to_image_node (graph, src, dst);
123+ }
124+ VK_THROW (" Invalid storage type transition from src to dst" );
41125}
42126
43127// Clone node is not the most efficient implementation for the aten.clone
0 commit comments