@@ -108,27 +108,15 @@ void add_slice_tensor_copy_node(
108108 spec_vars));
109109
110110 } else {
111- // GPU's coordinate is in x, y, z
112- int64_t gpu_dim = -1 ;
113- int64_t in_channel_stride = 1 ;
114- if (dim_index == kWidth4D ) {
115- gpu_dim = 0 ; // width: x dimension in gpu
116- VK_CHECK_COND (out_sizes[dim] == (1 + (end - start - 1 ) / step));
117- } else if (dim_index == kHeight4D ) {
118- gpu_dim = 1 ; // height: y dimension
119- VK_CHECK_COND (out_sizes[dim] == (1 + (end - start - 1 ) / step));
120- } else if (dim_index == kChannel4D ) {
121- gpu_dim = 2 ; // channel: z dimension
122- VK_CHECK_COND (out_sizes[dim] == (1 + (end - start - 1 ) / step));
123- in_channel_stride = dim_at (in_sizes, kChannel4D );
124- } else {
125- gpu_dim = 3 ; // batch: w dimension
126-
127- in_channel_stride = dim_at (in_sizes, kChannel4D );
128- if (packed_dim_idx == kChannel4D ) {
129- // Due to channel packing, each batch value is span over stride planes
130- in_channel_stride = utils::div_up_4 (in_channel_stride);
131- }
111+ // GPU's coordinate is in x = 0, y = 1, z = 2, w = 3
112+ const int64_t gpu_dim = -(dim_index + 1 );
113+ // stride of input tensor's channel dimension
114+ int64_t in_channel_stride = dim_at (in_sizes, kChannel4D );
115+ VK_CHECK_COND (out_sizes[dim] == (1 + (end - start - 1 ) / step));
116+
117+ // Due to channel packing, each batch value is span over stride planes
118+ if (dim_index == kBatch4D && packed_dim_idx == kChannel4D ) {
119+ in_channel_stride = utils::div_up_4 (in_channel_stride);
132120 }
133121
134122 std::string kernel_name = " slice_batch_height_width" ;
0 commit comments