@@ -23,8 +23,7 @@ void check_args(
2323 const api::vTensor& in,
2424 const std::vector<int64_t >& repeats,
2525 const api::vTensor& out) {
26- VK_CHECK_COND (check_packed_dim_is (in, WHCN::kChannelsDim ));
27- VK_CHECK_COND (check_packed_dim_is (out, WHCN::kChannelsDim ));
26+ VK_CHECK_COND (check_same_packed_dim (in, out));
2827
2928 VK_CHECK_COND (in.storage_type () == out.storage_type ());
3029 if (in.storage_type () == utils::kTexture2D ) {
@@ -59,147 +58,29 @@ void check_args(
5958
6059} // namespace
6160
62- void add_repeat_channel_node (
63- ComputeGraph& graph,
64- ValueRef in,
65- int64_t repeat_channel,
66- ValueRef out,
67- utils::ivec3& running_range) {
68- vTensorPtr t_in = graph.get_tensor (in);
69- vTensorPtr t_out = graph.get_tensor (out);
70-
71- std::string kernel_name = " repeat_channel" ;
72- kernel_name.reserve (kShaderNameReserve );
73- add_dtype_suffix (kernel_name, *t_out);
74-
75- const std::vector<int64_t >& in_sizes = t_in->sizes ();
76-
77- int32_t in_width = utils::safe_downcast<int32_t >(dim_at<kWidth4D >(in_sizes));
78- int32_t in_height =
79- utils::safe_downcast<int32_t >(dim_at<kHeight4D >(in_sizes));
80- int32_t in_channel =
81- utils::safe_downcast<int32_t >(dim_at<kChannel4D >(in_sizes));
82- int32_t in_batch = utils::safe_downcast<int32_t >(dim_at<kBatch4D >(in_sizes));
83-
84- int32_t out_channel = repeat_channel * in_channel;
85-
86- utils::ivec4 out_whcn_sizes{in_width, in_height, out_channel, in_batch};
87-
88- utils::ivec4 in_whcn_sizes{in_width, in_height, in_channel, in_batch};
89-
90- // Channel packed global work ids
91- running_range[2 ] = out_whcn_sizes[3 ] * utils::div_up_4 (out_whcn_sizes[2 ]);
92- utils::uvec3 global_size = utils::make_uvec3 (running_range);
93- utils::uvec3 local_size = adaptive_work_group_size (global_size);
94-
95- const struct Block final {
96- utils::ivec4 out_sizes;
97- utils::ivec4 in_size;
98- } repeat_channel_args{
99- out_whcn_sizes,
100- in_whcn_sizes,
101- };
102-
103- auto shader = VK_KERNEL_FROM_STR (kernel_name);
104-
105- graph.execute_nodes ().emplace_back (new DispatchNode (
106- graph,
107- VK_KERNEL_FROM_STR (kernel_name),
108- global_size,
109- local_size,
110- // Inputs and Outputs
111- {{out, vkapi::MemoryAccessType::WRITE},
112- {in, vkapi::MemoryAccessType::READ}},
113- // Parameter buffers
114- {graph.create_params_buffer (repeat_channel_args)},
115- // Specialization Constants
116- {SV (t_out->packed_dim ())}));
117- }
118-
11961void add_repeat_node (
12062 ComputeGraph& graph,
12163 ValueRef in,
12264 ValueRef repeats_ref,
12365 ValueRef out) {
124- std::vector<int64_t > repeats = *(graph.get_int_list (repeats_ref));
66+ const std::vector<int64_t > repeats = *(graph.get_int_list (repeats_ref));
12567
12668 vTensorPtr t_in = graph.get_tensor (in);
12769 vTensorPtr t_out = graph.get_tensor (out);
12870 check_args (*t_in, repeats, *t_out);
12971
130- // In this function, we expand the dimensions in the following order:
131- // 1. Channel
132- // 2. Width
133- // 3. Height
134- // 4. Batch
135- // After expanding a dimension, we will update the "running_range" since we
136- // will need to copy the "expanded" area.
137-
138- utils::ivec3 running_range = t_in->logical_limits ();
139-
140- const std::vector<int64_t >& in_sizes = t_in->sizes ();
141-
142- // Since we use channel packing, repeating the channel dimension is the most
143- // complicated and time-consuming, as we need to reason over misaligned
144- // channels. Hence we expand it first to minimize cost. Also, in this first
145- // dimension, we copy over the input texure to the output. In subsequent
146- // dimensions, we read and write from the same tensor.
147-
148- if (int64_t channel_repeat = dim_at<kChannel4D >(repeats);
149- channel_repeat == 1 ) {
150- // If no repeat, short-cut to a direct copy
151- utils::ivec4 src_offset{0 , 0 , 0 , 0 };
152- utils::ivec4 dst_offset{0 , 0 , 0 , 0 };
153-
154- add_copy_offset_node (
155- graph, in, running_range, src_offset, dst_offset, out, false , false );
156-
157- } else {
158- add_repeat_channel_node (graph, in, channel_repeat, out, running_range);
159- }
160-
161- // TODO: refactor width, height, and batch into a common helper function.
162- // Width
163- if (int64_t width_repeat = dim_at<kWidth4D >(repeats); width_repeat > 1 ) {
164- utils::ivec4 src_offset{0 , 0 , 0 , 0 };
165-
166- for (int i = 1 ; i < width_repeat; ++i) {
167- utils::ivec4 dst_offset{i * dim_at<kWidth4D >(in_sizes), 0 , 0 , 0 };
168-
169- add_copy_offset_node (
170- graph, out, running_range, src_offset, dst_offset, out, true , false );
171- }
172-
173- running_range[0 ] = running_range[0 ] * width_repeat;
174- }
175-
176- // Height
177- if (int64_t height_repeat = dim_at<kHeight4D >(repeats); height_repeat > 1 ) {
178- utils::ivec4 src_offset{0 , 0 , 0 , 0 };
179-
180- for (int i = 1 ; i < height_repeat; ++i) {
181- utils::ivec4 dst_offset = {0 , i * dim_at<kHeight4D >(in_sizes), 0 , 0 };
182-
183- add_copy_offset_node (
184- graph, out, running_range, src_offset, dst_offset, out, true , false );
185- }
186-
187- running_range[1 ] = running_range[1 ] * height_repeat;
188- }
189-
190- // Batch
191- if (int64_t batch_repeat = dim_at<kBatch4D >(repeats); batch_repeat > 1 ) {
192- utils::ivec4 src_offset{0 , 0 , 0 , 0 };
193-
194- for (int i = 1 ; i < batch_repeat; ++i) {
195- utils::ivec4 dst_offset = {0 , 0 , i * running_range[2 ], 0 };
196-
197- add_copy_offset_node (
198- graph, out, running_range, src_offset, dst_offset, out, true , false );
199- }
200-
201- running_range[2 ] = running_range[2 ] * batch_repeat;
202- }
72+ const utils::ivec4 src_offset{
73+ dim_at<kWidth4D >(t_in->sizes ()),
74+ dim_at<kHeight4D >(t_in->sizes ()),
75+ dim_at<kChannel4D >(t_in->sizes ()),
76+ dim_at<kBatch4D >(t_in->sizes ())};
77+ const utils::ivec4 dst_offset{
78+ dim_at<kWidth4D >(repeats),
79+ dim_at<kHeight4D >(repeats),
80+ dim_at<kChannel4D >(repeats),
81+ dim_at<kBatch4D >(repeats)};
82+ add_copy_packed_dim_offset_node (
83+ graph, in, t_out->logical_limits (), src_offset, dst_offset, out, true );
20384}
20485
20586void repeat (ComputeGraph& graph, const std::vector<ValueRef>& args) {