@@ -22,65 +22,68 @@ void add_cat_default_node(
2222 ValueRef dim_ref,
2323 ValueRef out) {
2424 ValueListPtr input_list = graph.get_value_list (in_list_ref);
25-
26- for (ValueRef input_ref : *input_list) {
27- vTensorPtr t_in = graph.get_tensor (input_ref);
28- VK_CHECK_COND (check_packed_dim_is (*t_in, WHCN::kChannelsDim ));
29- }
30-
3125 int64_t dim = graph.extract_scalar <int64_t >(dim_ref);
3226 vTensorPtr t_out = graph.get_tensor (out);
3327
28+ const auto packed_dim = t_out->packed_dim ();
29+ const auto packed_dim_index = static_cast <DimIndex>(kWidth4D - packed_dim);
30+
3431 DimIndex dim_index = normalize_to_dim_index (*t_out, dim);
32+ // Index of dimension to be concatenated in (w, h, c * b) coordinate system
33+ const auto dim_xyz_index = std::min (2 , -dim_index - 1 );
3534
36- // TODO: Find ways to factor out the similar code for width, height, and batch
37- if (dim_index == kWidth4D ) {
38- utils::ivec3 src_offset = utils::make_ivec3 ({0 , 0 , 0 }, false );
39- utils::ivec3 dst_offset = utils::make_ivec3 ({0 , 0 , 0 }, false );
35+ if (dim_index > kWidth4D || dim_index < kBatch4D ) {
36+ VK_THROW (" Unexpected value of dim_index=" , dim_index);
37+ }
4038
41- for (ValueRef input_ref : *input_list) {
42- vTensorPtr t_in = graph.get_tensor (input_ref);
43- utils::ivec3 range = t_in->logical_limits ();
44- add_copy_offset_node (
45- graph, input_ref, range, src_offset, dst_offset, out);
46- dst_offset[0 ] += range[0 ];
47- }
39+ utils::ivec4 src_offset = utils::make_ivec4 ({0 , 0 , 0 , 0 }, false );
40+ utils::ivec4 dst_offset = utils::make_ivec4 ({0 , 0 , 0 , 0 }, false );
4841
49- } else if (dim_index == kHeight4D ) {
50- utils::ivec3 src_offset = utils::make_ivec3 ({0 , 0 , 0 }, false );
51- utils::ivec3 dst_offset = utils::make_ivec3 ({0 , 0 , 0 }, false );
42+ const bool is_concat_channel = (dim_index == kChannel4D );
5243
53- for (ValueRef input_ref : *input_list) {
54- vTensorPtr t_in = graph.get_tensor (input_ref);
55- utils::ivec3 range = t_in->logical_limits ();
56- add_copy_offset_node (
57- graph, input_ref, range, src_offset, dst_offset, out);
58- dst_offset[1 ] += range[1 ];
59- }
60- } else if (dim_index == kBatch4D ) {
61- utils::ivec3 src_offset = utils::make_ivec3 ({0 , 0 , 0 }, false );
62- utils::ivec3 dst_offset = utils::make_ivec3 ({0 , 0 , 0 }, false );
44+ // if concatenating channels
45+ if (is_concat_channel) {
46+ // set destination offset w as channel size of the output tensor
47+ dst_offset[3 ] = dim_at (t_out->sizes (), kChannel4D );
48+ }
6349
64- for (ValueRef input_ref : *input_list) {
65- vTensorPtr t_in = graph.get_tensor (input_ref);
66- utils::ivec3 range = t_in->logical_limits ();
50+ for (ValueRef input_ref : *input_list) {
51+ const vTensorPtr t_in = graph.get_tensor (input_ref);
52+ const utils::ivec3 range = t_in->logical_limits ();
53+ const auto in_channel_size = dim_at (t_in->sizes (), kChannel4D );
54+ // if concatenating same dimension as the packed dimension
55+ if (dim_index == packed_dim_index) {
56+ // if concatenating channels, use add_copy_channel_offset_node function as
57+ // add_copy_packed_dim_offset_node does not support channel packing
58+ if (is_concat_channel) {
59+ add_copy_channel_offset_node (
60+ graph,
61+ input_ref,
62+ in_channel_size,
63+ src_offset[2 ],
64+ dst_offset[2 ],
65+ out);
66+ dst_offset[dim_xyz_index] += in_channel_size;
67+ } else {
68+ // src_offset[3] is not used now but will be used in the future when
69+ // add_copy_packed_dim_offset_node will support channel packing
70+ //
71+ // set source offset w as channel size of the input tensor if
72+ // concatenating channels
73+ src_offset[3 ] = is_concat_channel ? in_channel_size : 0 ;
74+ add_copy_packed_dim_offset_node (
75+ graph, input_ref, range, src_offset, dst_offset, out);
76+ dst_offset[dim_xyz_index] += dim_at (t_in->sizes (), packed_dim_index);
77+ }
78+ } else {
79+ // set source offset w as channel size of the input tensor if
80+ // concatenating channels
81+ src_offset[3 ] = is_concat_channel ? in_channel_size : 0 ;
6782 add_copy_offset_node (
6883 graph, input_ref, range, src_offset, dst_offset, out);
69- dst_offset[2 ] += range[2 ];
84+ dst_offset[dim_xyz_index] +=
85+ is_concat_channel ? in_channel_size : range[dim_xyz_index];
7086 }
71- } else if (dim_index == kChannel4D ) {
72- int32_t src_offset = 0 ;
73- int32_t dst_offset = 0 ;
74-
75- for (ValueRef input_ref : *input_list) {
76- vTensorPtr t_in = graph.get_tensor (input_ref);
77- int32_t range = dim_at (t_in->sizes (), kChannel4D );
78- add_copy_channel_offset_node (
79- graph, input_ref, range, src_offset, dst_offset, out);
80- dst_offset += range;
81- }
82- } else {
83- VK_THROW (" Unexpected value of dim_index=" , dim_index);
8487 }
8588}
8689
0 commit comments