|
10 | 10 |
|
11 | 11 | #include <executorch/backends/vulkan/runtime/graph/ops/impl/Clone.h>
|
12 | 12 | #include <executorch/backends/vulkan/runtime/graph/ops/impl/Permute.h>
|
| 13 | +#include <executorch/backends/vulkan/runtime/graph/ops/impl/View.h> |
13 | 14 | #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/KernelUtils.h>
|
14 | 15 | #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
|
15 | 16 |
|
@@ -55,8 +56,49 @@ void add_squeeze_copy_dims_node(
|
55 | 56 | }
|
56 | 57 | }
|
57 | 58 |
|
| 59 | +void resize_squeeze_node( |
| 60 | + ComputeGraph* graph, |
| 61 | + const std::vector<ArgGroup>& args, |
| 62 | + const std::vector<ValueRef>& extra_args) { |
| 63 | + const ValueRef out = args.at(0).refs.at(0); |
| 64 | + const ValueRef in = args.at(1).refs.at(0); |
| 65 | + const ValueRef dims_ref = extra_args.at(0); |
| 66 | + |
| 67 | + const IntListPtr dims = graph->get_int_list(dims_ref); |
| 68 | + |
| 69 | + std::vector<int64_t> out_sizes = graph->sizes_of(in); |
| 70 | + |
| 71 | + // Remove the dimensions specified in dims if their size is 1 |
| 72 | + for (int64_t dim : *dims) { |
| 73 | + if (dim >= 0 && dim < static_cast<int64_t>(out_sizes.size()) && |
| 74 | + out_sizes[dim] == 1) { |
| 75 | + out_sizes.erase(out_sizes.begin() + dim); |
| 76 | + // After erasing, all subsequent dims shift left by one |
| 77 | + // So we need to decrement all subsequent dims in dims |
| 78 | + for (auto& d : *dims) { |
| 79 | + if (d > dim) { |
| 80 | + --d; |
| 81 | + } |
| 82 | + } |
| 83 | + } |
| 84 | + } |
| 85 | + |
| 86 | + graph->virtual_resize(out, out_sizes); |
| 87 | +} |
| 88 | + |
58 | 89 | void squeeze_copy_dims(ComputeGraph& graph, const std::vector<ValueRef>& args) {
|
59 |
| - return add_squeeze_copy_dims_node(graph, args[0], args[1], args[2]); |
| 90 | + int idx = 0; |
| 91 | + const ValueRef in = args.at(idx++); |
| 92 | + const ValueRef dims = args.at(idx++); |
| 93 | + const ValueRef out = args.at(idx++); |
| 94 | + |
| 95 | + std::vector<ValueRef> resize_args = {dims}; |
| 96 | + |
| 97 | + if (graph.is_buffer_storage(in)) { |
| 98 | + return add_view_copy_buffer_node( |
| 99 | + graph, in, out, resize_args, resize_squeeze_node); |
| 100 | + } |
| 101 | + return add_squeeze_copy_dims_node(graph, in, dims, out); |
60 | 102 | }
|
61 | 103 |
|
62 | 104 | REGISTER_OPERATORS {
|
|
0 commit comments