@@ -23,8 +23,7 @@ void check_args(
2323    const  api::vTensor& in,
2424    const  std::vector<int64_t >& repeats,
2525    const  api::vTensor& out) {
26-   VK_CHECK_COND (check_packed_dim_is (in, WHCN::kChannelsDim ));
27-   VK_CHECK_COND (check_packed_dim_is (out, WHCN::kChannelsDim ));
26+   VK_CHECK_COND (check_same_packed_dim (in, out));
2827
2928  VK_CHECK_COND (in.storage_type () == out.storage_type ());
3029  if  (in.storage_type () == utils::kTexture2D ) {
@@ -59,147 +58,29 @@ void check_args(
5958
6059} //  namespace
6160
62- void  add_repeat_channel_node (
63-     ComputeGraph& graph,
64-     ValueRef in,
65-     int64_t  repeat_channel,
66-     ValueRef out,
67-     utils::ivec3& running_range) {
68-   vTensorPtr t_in = graph.get_tensor (in);
69-   vTensorPtr t_out = graph.get_tensor (out);
70- 
71-   std::string kernel_name = " repeat_channel" 
72-   kernel_name.reserve (kShaderNameReserve );
73-   add_dtype_suffix (kernel_name, *t_out);
74- 
75-   const  std::vector<int64_t >& in_sizes = t_in->sizes ();
76- 
77-   int32_t  in_width = utils::safe_downcast<int32_t >(dim_at<kWidth4D >(in_sizes));
78-   int32_t  in_height =
79-       utils::safe_downcast<int32_t >(dim_at<kHeight4D >(in_sizes));
80-   int32_t  in_channel =
81-       utils::safe_downcast<int32_t >(dim_at<kChannel4D >(in_sizes));
82-   int32_t  in_batch = utils::safe_downcast<int32_t >(dim_at<kBatch4D >(in_sizes));
83- 
84-   int32_t  out_channel = repeat_channel * in_channel;
85- 
86-   utils::ivec4 out_whcn_sizes{in_width, in_height, out_channel, in_batch};
87- 
88-   utils::ivec4 in_whcn_sizes{in_width, in_height, in_channel, in_batch};
89- 
90-   //  Channel packed global work ids
91-   running_range[2 ] = out_whcn_sizes[3 ] * utils::div_up_4 (out_whcn_sizes[2 ]);
92-   utils::uvec3 global_size = utils::make_uvec3 (running_range);
93-   utils::uvec3 local_size = adaptive_work_group_size (global_size);
94- 
95-   const  struct  Block  final  {
96-     utils::ivec4 out_sizes;
97-     utils::ivec4 in_size;
98-   } repeat_channel_args{
99-       out_whcn_sizes,
100-       in_whcn_sizes,
101-   };
102- 
103-   auto  shader = VK_KERNEL_FROM_STR (kernel_name);
104- 
105-   graph.execute_nodes ().emplace_back (new  DispatchNode (
106-       graph,
107-       VK_KERNEL_FROM_STR (kernel_name),
108-       global_size,
109-       local_size,
110-       //  Inputs and Outputs
111-       {{out, vkapi::MemoryAccessType::WRITE},
112-        {in, vkapi::MemoryAccessType::READ}},
113-       //  Parameter buffers
114-       {graph.create_params_buffer (repeat_channel_args)},
115-       //  Specialization Constants
116-       {SV (t_out->packed_dim ())}));
117- }
118- 
11961void  add_repeat_node (
12062    ComputeGraph& graph,
12163    ValueRef in,
12264    ValueRef repeats_ref,
12365    ValueRef out) {
124-   std::vector<int64_t > repeats = *(graph.get_int_list (repeats_ref));
66+   const   std::vector<int64_t > repeats = *(graph.get_int_list (repeats_ref));
12567
12668  vTensorPtr t_in = graph.get_tensor (in);
12769  vTensorPtr t_out = graph.get_tensor (out);
12870  check_args (*t_in, repeats, *t_out);
12971
130-   //  In this function, we expand the dimensions in the following order:
131-   //  1. Channel
132-   //  2. Width
133-   //  3. Height
134-   //  4. Batch
135-   //  After expanding a dimension, we will update the "running_range" since we
136-   //  will need to copy the "expanded" area.
137- 
138-   utils::ivec3 running_range = t_in->logical_limits ();
139- 
140-   const  std::vector<int64_t >& in_sizes = t_in->sizes ();
141- 
142-   //  Since we use channel packing, repeating the channel dimension is the most
143-   //  complicated and time-consuming, as we need to reason over misaligned
144-   //  channels. Hence we expand it first to minimize cost. Also, in this first
145-   //  dimension, we copy over the input texure to the output. In subsequent
146-   //  dimensions, we read and write from the same tensor.
147- 
148-   if  (int64_t  channel_repeat = dim_at<kChannel4D >(repeats);
149-       channel_repeat == 1 ) {
150-     //  If no repeat, short-cut to a direct copy
151-     utils::ivec4 src_offset{0 , 0 , 0 , 0 };
152-     utils::ivec4 dst_offset{0 , 0 , 0 , 0 };
153- 
154-     add_copy_offset_node (
155-         graph, in, running_range, src_offset, dst_offset, out, false , false );
156- 
157-   } else  {
158-     add_repeat_channel_node (graph, in, channel_repeat, out, running_range);
159-   }
160- 
161-   //  TODO: refactor width, height, and batch into a common helper function.
162-   //  Width
163-   if  (int64_t  width_repeat = dim_at<kWidth4D >(repeats); width_repeat > 1 ) {
164-     utils::ivec4 src_offset{0 , 0 , 0 , 0 };
165- 
166-     for  (int  i = 1 ; i < width_repeat; ++i) {
167-       utils::ivec4 dst_offset{i * dim_at<kWidth4D >(in_sizes), 0 , 0 , 0 };
168- 
169-       add_copy_offset_node (
170-           graph, out, running_range, src_offset, dst_offset, out, true , false );
171-     }
172- 
173-     running_range[0 ] = running_range[0 ] * width_repeat;
174-   }
175- 
176-   //  Height
177-   if  (int64_t  height_repeat = dim_at<kHeight4D >(repeats); height_repeat > 1 ) {
178-     utils::ivec4 src_offset{0 , 0 , 0 , 0 };
179- 
180-     for  (int  i = 1 ; i < height_repeat; ++i) {
181-       utils::ivec4 dst_offset = {0 , i * dim_at<kHeight4D >(in_sizes), 0 , 0 };
182- 
183-       add_copy_offset_node (
184-           graph, out, running_range, src_offset, dst_offset, out, true , false );
185-     }
186- 
187-     running_range[1 ] = running_range[1 ] * height_repeat;
188-   }
189- 
190-   //  Batch
191-   if  (int64_t  batch_repeat = dim_at<kBatch4D >(repeats); batch_repeat > 1 ) {
192-     utils::ivec4 src_offset{0 , 0 , 0 , 0 };
193- 
194-     for  (int  i = 1 ; i < batch_repeat; ++i) {
195-       utils::ivec4 dst_offset = {0 , 0 , i * running_range[2 ], 0 };
196- 
197-       add_copy_offset_node (
198-           graph, out, running_range, src_offset, dst_offset, out, true , false );
199-     }
200- 
201-     running_range[2 ] = running_range[2 ] * batch_repeat;
202-   }
72+   const  utils::ivec4 src_offset{
73+       dim_at<kWidth4D >(t_in->sizes ()),
74+       dim_at<kHeight4D >(t_in->sizes ()),
75+       dim_at<kChannel4D >(t_in->sizes ()),
76+       dim_at<kBatch4D >(t_in->sizes ())};
77+   const  utils::ivec4 dst_offset{
78+       dim_at<kWidth4D >(repeats),
79+       dim_at<kHeight4D >(repeats),
80+       dim_at<kChannel4D >(repeats),
81+       dim_at<kBatch4D >(repeats)};
82+   add_copy_packed_dim_offset_node (
83+       graph, in, t_out->logical_limits (), src_offset, dst_offset, out, true );
20384}
20485
20586void  repeat (ComputeGraph& graph, const  std::vector<ValueRef>& args) {
0 commit comments