@@ -213,6 +213,21 @@ auto element_wise_registrations TRTORCH_UNUSED =
213
213
div->setName (util::node_info (n).c_str ());
214
214
auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], div->getOutput (0 ));
215
215
216
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
217
+ return true ;
218
+ }})
219
+ .pattern({" aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
220
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
221
+ // TODO: Remove with functionalization
222
+ auto self = args[0 ].ITensorOrFreeze (ctx);
223
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
224
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
225
+ auto div =
226
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other, util::node_info (n));
227
+ TRTORCH_CHECK (div, " Unable to create div layer from node: " << *n);
228
+
229
+ div->setName (util::node_info (n).c_str ());
230
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], div->getOutput (0 ));
216
231
LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
217
232
return true ;
218
233
}})
@@ -229,6 +244,21 @@ auto element_wise_registrations TRTORCH_UNUSED =
229
244
div->setName (util::node_info (n).c_str ());
230
245
auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], div->getOutput (0 ));
231
246
247
+ LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
248
+ return true ;
249
+ }})
250
+ .pattern({" aten::div_.Scalar(Tensor self, Scalar other) -> (Tensor)" ,
251
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
252
+ // TODO: Remove with functionalization
253
+ auto self = args[0 ].ITensorOrFreeze (ctx);
254
+ auto otherScalar = args[1 ].unwrapToScalar ().to <float >();
255
+ auto other = tensor_to_const (ctx, torch::tensor ({otherScalar}));
256
+ auto div =
257
+ add_elementwise (ctx, nvinfer1::ElementWiseOperation::kDIV , self, other, util::node_info (n));
258
+ TRTORCH_CHECK (div, " Unable to create div layer from node: " << *n);
259
+
260
+ div->setName (util::node_info (n).c_str ());
261
+ auto out = ctx->AssociateValueAndTensor (n->outputs ()[0 ], div->getOutput (0 ));
232
262
LOG_DEBUG (" Output tensor shape: " << out->getDimensions ());
233
263
return true ;
234
264
}})
0 commit comments