 
 #include <executorch/backends/vulkan/runtime/graph/Logging.h>
 
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Slice.h>
+
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>
@@ -31,7 +33,7 @@ inline int64_t normalize_idx(
   return normalize(index, max);
 }
 
-void add_slice_tensor_out_node(
+void add_slice_tensor_copy_node(
     ComputeGraph& graph,
     ValueRef in,
     ValueRef dim_ref,
@@ -149,8 +151,126 @@ void add_slice_tensor_out_node(
   }
 }
 
-void slice_tensor_out(ComputeGraph& graph, const std::vector<ValueRef>& args) {
-  return add_slice_tensor_out_node(
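+// Computes the sizes of the slice output by resolving the optional start/end
+// arguments (normalizing negative indices) against the size of the input
+// tensor at `dim`.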
+std::vector<int64_t> get_slice_sizes(
+    ComputeGraph& graph,
+    ValueRef in_ref,
+    ValueRef dim_ref,
+    ValueRef opt_start_ref,
+    ValueRef opt_end_ref) {
+  const int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
+  std::optional<int64_t> opt_start =
+      graph.extract_optional_scalar<int64_t>(opt_start_ref);
+  std::optional<int64_t> opt_end =
+      graph.extract_optional_scalar<int64_t>(opt_end_ref);
+
+  int64_t dim_size = graph.size_at<int64_t>(dim, in_ref);
+  int64_t start = opt_start.value_or(0);
+  int64_t end = opt_end.value_or(dim_size);
+
+  start = normalize_idx(start, dim_size, 0);
+  end = normalize_idx(end, dim_size, dim_size);
+
+  std::vector<int64_t> new_out_sizes = graph.sizes_of(in_ref);
+  new_out_sizes.at(dim) = end - start;
+
+  return new_out_sizes;
+}
+
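+// Resize callback for the slice view node: recomputes the slice's output
+// sizes from the current input sizes and applies them to the view tensor.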
+void resize_slice_view_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  (void)args;
+  vTensorPtr out = graph->get_tensor(extra_args[0]);
+
+  std::vector<int64_t> new_out_sizes = get_slice_sizes(
+      *graph,
+      extra_args[1], // input
+      extra_args[2], // dim
+      extra_args[3], // optional start
+      extra_args[4]); // optional end
+
+  out->virtual_resize(new_out_sizes);
+}
+
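+// Checks that the requested slice can be represented as a zero-copy view of
+// the input tensor's buffer storage.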
+void check_slice_view_args(
+    ComputeGraph& graph,
+    ValueRef in_ref,
+    ValueRef dim_ref,
+    ValueRef opt_start_ref,
+    ValueRef opt_end_ref,
+    ValueRef opt_step_ref,
+    ValueRef out_ref) {
+  VK_CHECK_COND(
+      graph.val_is_view_of(out_ref, in_ref),
+      "output must be a view of the input");
+
+  const int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
+  const int64_t dim_size = graph.size_at<int64_t>(dim, in_ref);
+
+  int64_t start =
+      graph.extract_optional_scalar<int64_t>(opt_start_ref).value_or(0);
+  int64_t end = graph.extract_optional_scalar<int64_t>(opt_end_ref).value_or(0);
+  int64_t step =
+      graph.extract_optional_scalar<int64_t>(opt_step_ref).value_or(1);
+
+  start = normalize_idx(start, dim_size, 0);
+  end = normalize_idx(end, dim_size, dim_size);
+
+  // The start idx must be 0; this is to ensure that the start of the slice
+  // view does not have any offset with respect to the base buffer storage. If
+  // the offset is nonzero, then it will potentially change upon a resize;
+  // however, the buffer offset of the view tensor will have been "locked in"
+  // when the descriptor for its buffer storage is bound to a compute shader.
+  // Therefore there is no way to update the offset of the view once it has
+  // been bound.
+  VK_CHECK_COND(start == 0, "start must be 0 for slice view");
+  VK_CHECK_COND(step == 1, "step must be 1 for slice view");
+
+  VK_CHECK_COND(
+      end < dim_size, "end must be less than dim size for slice view");
+
+  // We must also check that all earlier dims in the dim order have a size of
+  // 1. This ensures that the slice view encompasses a contiguous memory
+  // region of the source tensor's memory buffer.
+  std::vector<int64_t> in_sizes = graph.sizes_of(in_ref);
+  std::vector<int64_t> in_dim_order = graph.dim_order_of(in_ref);
+  for (int i = 0; i < in_dim_order.size(); ++i) {
+    if (in_dim_order[i] == dim) {
+      break;
+    }
+    VK_CHECK_COND(in_sizes[in_dim_order[i]] == 1);
+  }
+}
+
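+// Adds a slice-as-view node: validates the arguments, sizes the output view,
+// and registers only a resize function (no compute shader is dispatched) so
+// that the view's sizes track future input resizes.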
+void add_slice_view_node(
+    ComputeGraph& graph,
+    ValueRef in_ref,
+    ValueRef dim_ref,
+    ValueRef opt_start_ref,
+    ValueRef opt_end_ref,
+    ValueRef opt_step_ref,
+    ValueRef out_ref) {
+  check_slice_view_args(
+      graph,
+      in_ref,
+      dim_ref,
+      opt_start_ref,
+      opt_end_ref,
+      opt_step_ref,
+      out_ref);
+
+  std::vector<int64_t> new_out_sizes =
+      get_slice_sizes(graph, in_ref, dim_ref, opt_start_ref, opt_end_ref);
+
+  graph.get_tensor(out_ref)->virtual_resize(new_out_sizes);
+
+  graph.execute_nodes().emplace_back(new ExecuteNode(
+      resize_slice_view_node,
+      {out_ref, in_ref, dim_ref, opt_start_ref, opt_end_ref, opt_step_ref}));
+}
+
+void slice_tensor_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_slice_tensor_copy_node(
       graph,
       args[0],
       args[1], // dim
@@ -160,9 +280,36 @@ void slice_tensor_out(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       args[5]);
 }
 
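+// aten.slice.Tensor can be lowered either as a zero-copy view (when the out
+// tensor aliases the input) or as a copy; dispatch accordingly.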
+void slice_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  ValueRef in = args[0];
+  ValueRef out = args[5];
+
+  // Special case if out is a view of in
+  if (graph.val_is_view_of(out, in)) {
+    add_slice_view_node(
+        graph,
+        in,
+        args[1], // dim
+        args[2], // optional start
+        args[3], // optional end
+        args[4], // step
+        out);
+    return;
+  }
+
+  add_slice_tensor_copy_node(
+      graph,
+      in,
+      args[1], // dim
+      args[2], // optional start
+      args[3], // optional end
+      args[4], // step
+      out);
+}
+
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(aten.slice_copy.Tensor, slice_tensor_out);
-  VK_REGISTER_OP(aten.slice.Tensor, slice_tensor_out);
+  VK_REGISTER_OP(aten.slice_copy.Tensor, slice_tensor_copy);
+  VK_REGISTER_OP(aten.slice.Tensor, slice_tensor);
 }
 
 } // namespace vkcompute