@@ -24,28 +24,28 @@ auto prim_registrations =
24
24
RegisterNodeEvaluators ()
25
25
.evaluator(
26
26
{torch::jit::prim::Constant,
27
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
27
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
28
28
if (n->output ()->type ()->kind () == at::FunctionType::Kind) {
29
29
return {};
30
30
}
31
31
return evaluators::toIValue (n->output ());
32
32
}})
33
33
.evaluator(
34
34
{torch::jit::prim::NumToTensor,
35
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
35
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
36
36
return evaluators::scalar_to_tensor (args.at (n->input (0 )).IValue ()->toScalar ());
37
37
}})
38
38
.evaluator(
39
39
{torch::jit::prim::ListUnpack,
40
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
40
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
41
41
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
42
42
const torch::jit::IValue* outputs = args.at (n->input ()).IValue ();
43
43
auto outputVec = outputs->toList ().vec ();
44
44
return std::move (c10::ivalue::Tuple::create (outputVec));
45
45
}})
46
46
.evaluator(
47
47
{torch::jit::prim::ListConstruct,
48
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
48
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
49
49
const auto num_inputs = n->inputs ().size ();
50
50
if (constTypesOnly (args)) {
51
51
c10::ListTypePtr lt = n->output ()->type ()->expect <c10::ListType>();
@@ -103,8 +103,14 @@ auto prim_registrations =
103
103
if (args.at (in).IValue ()->isNone ()) {
104
104
auto ival = torch::jit::IValue ();
105
105
list.emplace_back (std::move (ival));
106
+ } else if (args.at (in).IValue ()->isInt ()) {
107
+ auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const (ctx, torch::tensor (args.at (in).unwrapToInt ()));
108
+ auto tensor_holder = TensorContainer ();
109
+ tensor_holder.hold_tensor (itensor);
110
+ auto ival = c10::IValue (std::move (c10::make_intrusive<TensorContainer>(tensor_holder)));
111
+ list.emplace_back (std::move (ival));
106
112
} else {
107
- list.emplace_back (std::move (args.at (in).unwrapToTensor ()));
113
+ list.emplace_back (std::move (args.at (in).unwrapToTensor ()));
108
114
}
109
115
}
110
116
}
@@ -113,7 +119,7 @@ auto prim_registrations =
113
119
}})
114
120
.evaluator(
115
121
{c10::Symbol::fromQualString (" prim::dtype" ),
116
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
122
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
117
123
auto input = args.at (n->input (0 ));
118
124
if (input.isITensor ()) {
119
125
auto trt_dtype = input.ITensor ()->getType ();
@@ -136,7 +142,7 @@ auto prim_registrations =
136
142
})})
137
143
.evaluator(
138
144
{c10::Symbol::fromQualString (" prim::min" ),
139
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
145
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
140
146
if (n->inputs ().size () == 1 ) {
141
147
auto a = args.at (n->input (0 )).unwrapToIntList ();
142
148
int64_t min = std::numeric_limits<int64_t >::max ();
@@ -198,7 +204,7 @@ auto prim_registrations =
198
204
})})
199
205
.evaluator(
200
206
{c10::Symbol::fromQualString (" prim::max" ),
201
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
207
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
202
208
if (n->inputs ().size () == 1 ) {
203
209
auto a = args.at (n->input (0 )).unwrapToIntList ();
204
210
int64_t max = std::numeric_limits<int64_t >::min ();
@@ -260,7 +266,7 @@ auto prim_registrations =
260
266
})})
261
267
.evaluator(
262
268
{c10::Symbol::fromQualString (" prim::shape" ),
263
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
269
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
264
270
LOG_WARNING (" There may be undefined behavior using dynamic shape and prim::shape" );
265
271
auto tensor_var = args.at (n->input (0 ));
266
272
if (tensor_var.isITensor ()) {
@@ -274,7 +280,7 @@ auto prim_registrations =
274
280
EvalOptions ().validSchemas ({" prim::shape(Tensor a) -> (int[])" })})
275
281
.evaluator(
276
282
{torch::jit::prim::TupleConstruct,
277
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
283
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
278
284
c10::IValue tuple = c10::ivalue::Tuple::create ();
279
285
std::vector<c10::IValue> elems;
280
286
for (auto in : n->inputs ()) {
@@ -292,7 +298,7 @@ auto prim_registrations =
292
298
}})
293
299
.evaluator(
294
300
{torch::jit::prim::TupleIndex,
295
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
301
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
296
302
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
297
303
auto tuple = args.at (n->input (0 )).IValue ()->toTuple ();
298
304
int64_t idx = args.at (n->input (1 )).IValue ()->toInt ();
@@ -302,24 +308,24 @@ auto prim_registrations =
302
308
EvalOptions ().validSchemas ({" prim::TupleIndex(Any tup, int i) -> (Any)" })})
303
309
.evaluator(
304
310
{torch::jit::prim::TupleUnpack,
305
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
311
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
306
312
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
307
313
auto output = args.at (n->input ()).IValue ()->toTuple ();
308
314
return c10::optional<torch::jit::IValue>(std::move (output));
309
315
}})
310
316
.evaluator(
311
317
{c10::Symbol::fromQualString (" prim::unchecked_cast" ),
312
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
318
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
313
319
return *(args.at (n->input (0 )).IValue ());
314
320
}})
315
321
.evaluator(
316
322
{c10::Symbol::fromQualString (" prim::Uninitialized" ),
317
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
323
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
318
324
return c10::IValue::uninitialized ();
319
325
}})
320
326
.evaluator(
321
327
{c10::Symbol::fromQualString (" prim::RaiseException" ),
322
- [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
328
+ [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
323
329
auto exception = args.at (n->input (0 )).IValue ();
324
330
TORCHTRT_THROW_ERROR (" Error from TorchScript: " << *exception);
325
331
return {};
@@ -328,4 +334,4 @@ auto prim_registrations =
328
334
} // namespace evaluators
329
335
} // namespace conversion
330
336
} // namespace core
331
- } // namespace torch_tensorrt
337
+ } // namespace torch_tensorrt
0 commit comments