@@ -44,8 +44,7 @@ void add_slice_tensor_copy_node(
   vTensorPtr t_in = graph.get_tensor(in);
   vTensorPtr t_out = graph.get_tensor(out);

-  VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim));
-  VK_CHECK_COND(check_packed_dim_is(*t_out, WHCN::kChannelsDim));
+  VK_CHECK_COND(check_same_packed_dim(*t_in, *t_out));

   // Need to normalize the dim
   int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
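The `normalize_idx` calls at the top of the next hunk resolve Python-style negative indices for `start` and `end` before a kernel is chosen. A rough sketch of the assumed semantics (an illustration, not the library's actual implementation; the INT64_MAX sentinel and the clamping are assumptions):

    #include <algorithm>
    #include <cstdint>

    // Illustrative sketch only: negative indices count back from `size`,
    // out-of-range values are clamped to [0, size], and an assumed
    // "index not provided" sentinel falls back to `default_value`.
    int64_t normalize_idx_sketch(int64_t idx, int64_t size, int64_t default_value) {
      if (idx == INT64_MAX) {
        return default_value;  // assumed sentinel for an unset start/end
      }
      if (idx < 0) {
        idx += size;  // e.g. -1 becomes size - 1
      }
      return std::min(std::max<int64_t>(idx, 0), size);
    }

Called as in the hunk below, `normalize_idx(start, in_sizes[dim], 0)` defaults an unset start to 0, while `normalize_idx(end, in_sizes[dim], in_sizes[dim])` defaults an unset end to the full extent of the dimension.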
@@ -76,7 +75,13 @@ void add_slice_tensor_copy_node(
   start = normalize_idx(start, in_sizes[dim], 0);
   end = normalize_idx(end, in_sizes[dim], in_sizes[dim]);

-  if (dim_index == kChannel4D) {
+  const vkapi::SpecVarList spec_vars = {t_in->packed_dim()};
+
+  const auto packed_dim_idx =
+      static_cast<DimIndex>(DimIndex::DIM_LAST - t_in->packed_dim());
+
+  // if slice dim is the same as the packed dim, we can use the channel slice
+  if (dim_index == packed_dim_idx) {
     // slice by channel
     std::string kernel_name = "slice_channel";
     kernel_name.reserve(kShaderNameReserve);
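The `DimIndex::DIM_LAST - t_in->packed_dim()` conversion above turns a WHCN axis (0 = width, 1 = height, 2 = channels) into the negative-indexed `DimIndex` convention used by `dim_index`. A minimal sketch under the assumption that the enumerators count back from the innermost (width) dimension; the concrete values here are illustrative, not quoted from the library:

    #include <cassert>
    #include <cstdint>

    // Assumed enumerator values for illustration; they mirror negative
    // Python-style indexing from the innermost (width) dimension.
    enum DimIndexSketch : int32_t {
      kBatch4D_s = -4,
      kChannel4D_s = -3,
      kHeight4D_s = -2,
      kWidth4D_s = -1,  // DIM_LAST under this assumption
    };
    constexpr int32_t kDimLast = -1;

    int main() {
      // A channels-packed tensor (packed_dim() == 2) maps to kChannel4D:
      assert(static_cast<DimIndexSketch>(kDimLast - 2) == kChannel4D_s);
      // A width-packed tensor (packed_dim() == 0) maps to kWidth4D:
      assert(static_cast<DimIndexSketch>(kDimLast - 0) == kWidth4D_s);
      return 0;
    }

Under that convention, comparing `dim_index` against `packed_dim_idx` asks exactly "is the slice axis the packed axis?", which is what gates the `slice_channel` kernel below.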
@@ -99,26 +104,31 @@ void add_slice_tensor_copy_node(
          {in, vkapi::MemoryAccessType::READ}},
         {t_out->sizes_ubo(),
          t_in->sizes_ubo(),
-         graph.create_params_buffer(params)}));
+         graph.create_params_buffer(params)},
+        spec_vars));

   } else {
     // GPU coordinates are in x, y, z
     int64_t gpu_dim = -1;
-    int64_t stride = 1;
+    int64_t in_channel_stride = 1;
     if (dim_index == kWidth4D) {
       gpu_dim = 0; // width: x dimension in gpu
       VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
     } else if (dim_index == kHeight4D) {
       gpu_dim = 1; // height: y dimension
       VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
-    } else if (dim_index == kBatch4D) {
-      gpu_dim = 2; // batch: z dimension
-
-      // Due to channel packing, each batch value is span over stride planes
-      int64_t n_channels = dim_at(in_sizes, kChannel4D);
-      stride = utils::div_up_4(n_channels);
+    } else if (dim_index == kChannel4D) {
+      gpu_dim = 2; // channel: z dimension
+      VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step));
+      in_channel_stride = dim_at(in_sizes, kChannel4D);
     } else {
-      VK_THROW("Unexpected ncwh_dim!");
+      gpu_dim = 3; // batch: w dimension
+
+      in_channel_stride = dim_at(in_sizes, kChannel4D);
+      if (packed_dim_idx == kChannel4D) {
+        // Due to channel packing, each batch value spans in_channel_stride planes
+        in_channel_stride = utils::div_up_4(in_channel_stride);
+      }
     }

     std::string kernel_name = "slice_batch_height_width";
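To make the batch branch concrete: when the input is channel-packed, four channels share one texel, so each batch occupies ceil(C / 4) texel planes along the GPU z axis rather than C. A worked example, assuming `utils::div_up_4` is a plain ceiling division by 4:

    #include <cstdint>
    #include <iostream>

    // Assumed behavior of utils::div_up_4: ceiling division by 4.
    int64_t div_up_4(int64_t n) {
      return (n + 3) / 4;
    }

    int main() {
      const int64_t n_channels = 6;  // example: C = 6
      // Channel-packed: each batch spans ceil(6 / 4) = 2 texel planes.
      std::cout << "packed stride:   " << div_up_4(n_channels) << "\n";
      // Otherwise: every channel gets its own plane, so the stride is C = 6.
      std::cout << "unpacked stride: " << n_channels << "\n";
      return 0;
    }

This is why the new code only applies `div_up_4` when `packed_dim_idx == kChannel4D`, where the old code could apply it unconditionally because channel packing was a hard precondition.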
@@ -137,7 +147,7 @@ void add_slice_tensor_copy_node(
         static_cast<int32_t>(gpu_dim),
         static_cast<int32_t>(start),
         static_cast<int32_t>(step),
-        static_cast<int32_t>(stride),
+        static_cast<int32_t>(in_channel_stride),
     };

     graph.execute_nodes().emplace_back(new DispatchNode(
@@ -147,7 +157,8 @@ void add_slice_tensor_copy_node(
         local_size,
         {{out, vkapi::MemoryAccessType::WRITE},
          {in, vkapi::MemoryAccessType::READ}},
-        {t_out->sizes_ubo(), graph.create_params_buffer(params)}));
+        {t_out->sizes_ubo(), graph.create_params_buffer(params)},
+        spec_vars));
   }
 }

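Finally, the repeated `VK_CHECK_COND(out_sizes[dim] == (1 + (end - start - 1) / step))` checks encode the standard slice-length formula: the indices start, start + step, ... that stay strictly below end number exactly 1 + (end - start - 1) / step. A quick sanity check of that identity (assumes end > start after normalization):

    #include <cassert>
    #include <cstdint>

    // Number of indices start, start + step, ... that stay below end.
    int64_t slice_len(int64_t start, int64_t end, int64_t step) {
      return 1 + (end - start - 1) / step;
    }

    int main() {
      assert(slice_len(0, 10, 1) == 10);  // every element
      assert(slice_len(0, 10, 3) == 4);   // 0, 3, 6, 9
      assert(slice_len(2, 9, 2) == 4);    // 2, 4, 6, 8
      return 0;
    }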