@@ -125,6 +125,81 @@ static auto shuffle_registrations TRTORCH_UNUSED =
                    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
                    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

+                   return true;
+                 }})
+        .pattern({"aten::pixel_shuffle(Tensor self, int upscale_factor) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto in_shape = util::toVec(self->getDimensions());
+                    int64_t irank = in_shape.size();
+                    TRTORCH_CHECK(
+                        irank >= 3,
+                        "pixel_shuffle expects input to have at least 3 dimensions, but got input with "
+                            << irank << " dimension(s)");
+                    int64_t upscale_factor = args[1].unwrapToInt();
+                    TRTORCH_CHECK(
+                        upscale_factor > 0,
+                        "pixel_shuffle expects a positive upscale_factor, but got " << upscale_factor);
+                    int64_t upscale_factor_squared = upscale_factor * upscale_factor;
+
+                    const auto NUM_NON_BATCH_DIMS = 3;
+                    const auto self_sizes_batch_end = in_shape.end() - NUM_NON_BATCH_DIMS;
+
+                    int64_t ic = in_shape[irank - 3];
+                    int64_t ih = in_shape[irank - 2];
+                    int64_t iw = in_shape[irank - 1];
+
+                    TRTORCH_CHECK(
+                        ic % upscale_factor_squared == 0,
+                        "pixel_shuffle expects its input's 'channel' dimension to be divisible by the square of "
+                            << "upscale_factor, but input.size(-3)=" << ic << " is not divisible by "
+                            << upscale_factor_squared);
+
+                    int64_t oc = ic / upscale_factor_squared;
+                    int64_t oh = ih * upscale_factor;
+                    int64_t ow = iw * upscale_factor;
+
+                    // First, reshape to split the channels dim from c into 3 separate dims: (oc,
+                    // upscale_factor, upscale_factor). This allows shuffling to be done next by
+                    // permuting dims.
+                    std::vector<int64_t> added_dims_shape(in_shape.begin(), self_sizes_batch_end);
+                    added_dims_shape.insert(added_dims_shape.end(), {oc, upscale_factor, upscale_factor, ih, iw});
+                    auto view_layer = ctx->net->addShuffle(*self);
+                    TRTORCH_CHECK(view_layer, "Unable to create shuffle layer from node: " << *n);
+                    view_layer->setReshapeDimensions(util::toDims(added_dims_shape));
+                    int64_t view_rank = added_dims_shape.size();
+
+                    // Next, shuffle by permuting the new upscale_factor dims alongside the height and width dims.
+                    auto permutation_layer = ctx->net->addShuffle(*view_layer->getOutput(0));
+                    TRTORCH_CHECK(permutation_layer, "Unable to create shuffle layer from node: " << *n);
+                    // std::iota is used to maintain the batch dims within the permutation.
+                    // Eg: if added_dims_shape is {n1, n2, c, r, r, h, w}, then the new_order is {view_rank-7,
+                    // view_rank-6, view_rank-5, view_rank-2, view_rank-4, view_rank-1, view_rank-3}
+                    std::vector<int64_t> new_order(in_shape.begin(), self_sizes_batch_end);
+                    std::iota(new_order.begin(), new_order.end(), 0);
+                    new_order.insert(
+                        new_order.end(),
+                        {view_rank - 5 /* oc */,
+                         view_rank - 2 /* ih */,
+                         view_rank - 4 /* 1st upscale_factor */,
+                         view_rank - 1 /* iw */,
+                         view_rank - 3 /* 2nd upscale_factor */});
+                    nvinfer1::Permutation permute;
+                    std::copy(new_order.begin(), new_order.end(), permute.order);
+                    permutation_layer->setSecondTranspose(permute);
+
+                    // Finally, upscale by collapsing (ih, upscale_factor) -> a single dim (oh)
+                    // and (iw, upscale_factor) -> a single dim (ow).
+                    std::vector<int64_t> final_shape(in_shape.begin(), self_sizes_batch_end);
+                    final_shape.insert(final_shape.end(), {oc, oh, ow});
+                    auto last_view_layer = ctx->net->addShuffle(*permutation_layer->getOutput(0));
+                    TRTORCH_CHECK(last_view_layer, "Unable to create shuffle layer from node: " << *n);
+                    last_view_layer->setReshapeDimensions(util::toDims(final_shape));
+                    last_view_layer->setName(util::node_info(n).c_str());
+
+                    auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], last_view_layer->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
+
                    return true;
                  }});
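
For reference, this converter lowers aten::pixel_shuffle to three TensorRT shuffle layers: a reshape that splits the channel dim into (oc, upscale_factor, upscale_factor), a permutation that interleaves the two upscale_factor dims with height and width, and a final reshape that collapses them into (oc, oh, ow). The snippet below is a minimal standalone sketch (not part of the commit, example shape made up) of the same shape and permutation arithmetic, so the new_order construction can be checked without TensorRT:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical example: batch of 2, 8 channels, 3x3 spatial, upscale_factor r = 2.
  std::vector<int64_t> in_shape = {2, 8, 3, 3};
  int64_t r = 2;

  const int64_t irank = in_shape.size();
  const auto batch_end = in_shape.end() - 3; // everything before (C, H, W)

  int64_t ic = in_shape[irank - 3], ih = in_shape[irank - 2], iw = in_shape[irank - 1];
  int64_t oc = ic / (r * r), oh = ih * r, ow = iw * r;

  // Step 1: reshape (..., ic, ih, iw) -> (..., oc, r, r, ih, iw).
  std::vector<int64_t> added_dims_shape(in_shape.begin(), batch_end);
  added_dims_shape.insert(added_dims_shape.end(), {oc, r, r, ih, iw});
  int64_t view_rank = added_dims_shape.size();

  // Step 2: permute to (..., oc, ih, r, iw, r); std::iota keeps the batch dims in place.
  std::vector<int64_t> new_order(in_shape.begin(), batch_end);
  std::iota(new_order.begin(), new_order.end(), 0);
  new_order.insert(
      new_order.end(),
      {view_rank - 5, view_rank - 2, view_rank - 4, view_rank - 1, view_rank - 3});

  // Step 3: reshape (..., oc, ih, r, iw, r) -> (..., oc, oh, ow).
  std::vector<int64_t> final_shape(in_shape.begin(), batch_end);
  final_shape.insert(final_shape.end(), {oc, oh, ow});

  auto print = [](const char* name, const std::vector<int64_t>& v) {
    std::cout << name << ":";
    for (auto x : v) std::cout << ' ' << x;
    std::cout << '\n';
  };
  print("added_dims_shape", added_dims_shape); // 2 2 2 2 3 3
  print("new_order", new_order);               // 0 1 4 2 5 3
  print("final_shape", final_shape);           // 2 2 6 6
}

With these three steps an input of shape (2, 8, 3, 3) comes out as (2, 2, 6, 6), which matches PyTorch's documented pixel_shuffle behaviour (C / r^2 channels, H * r by W * r spatial).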