Commit ccad996

fix: Implement duality support for evaluators
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 24172f0 commit ccad996

9 files changed, +1171 −1155 lines changed

core/conversion/conversion.cpp

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
       return {};
     }
   }
-  auto eval = evaluators::EvalNode(n, eval_args);
+  auto eval = evaluators::EvalNode(ctx, n, eval_args);
   return eval;
 }
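
Note: the reason EvalNode now takes the ConversionCtx is that evaluators may need to emit TensorRT layers while evaluating (see the prim::ListConstruct change below). A minimal sketch of what this enables inside an evaluator body, assuming only tensor_to_const from converter_util.h, which evaluators.h now includes:

    // Inside an evaluator lambda: materialize a compile-time int as a TensorRT
    // constant. tensor_to_const adds a constant layer to ctx->net and returns
    // the layer's output ITensor*.
    nvinfer1::ITensor* itensor =
        converters::tensor_to_const(ctx, torch::tensor(args.at(n->input(0)).unwrapToInt()));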

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions

@@ -62,6 +62,7 @@ cc_library(
         "impl/constant_pad.cpp",
         "impl/conv_deconv.cpp",
         "impl/cumsum.cpp",
+        "impl/dual_ops.cpp",
         "impl/element_wise.cpp",
         "impl/expand.cpp",
         "impl/interpolate.cpp",

core/conversion/converters/impl/shuffle.cpp

Lines changed: 6 additions & 1 deletion

@@ -68,10 +68,14 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
     {"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
      [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
        auto in = args[0].ITensorOrFreeze(ctx);
+       std::cout << "====1====" << std::endl;
        auto in_shape = util::toVec(in->getDimensions());
+       std::cout << "====2====" << std::endl;
        std::vector<int64_t> new_shape;
        if (ctx->input_is_dynamic) {
+         std::cout << "====3====" << std::endl;
          new_shape = util::toVec(args[1].unwrapToIntList().vec());
+         std::cout << "====4====" << std::endl;
          int nbDynamicDims = 0;
          for (size_t i = 0; i < new_shape.size(); i++) {
            if (in_shape[i] == -1)
@@ -82,9 +86,10 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
              "Resize is currently not supported when target shape contains more than one dynamic dimension");
        }
      } else {
+       std::cout << "====5====" << std::endl;
        new_shape = torch::reshape(torch::rand(in_shape), args[1].unwrapToIntList().vec()).sizes().vec();
      }
-
+     std::cout << "====6====" << std::endl;
      auto shuffle = ctx->net->addShuffle(*in);
      TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
      shuffle->setReshapeDimensions(util::toDims(new_shape));

core/conversion/evaluators/NodeEvaluatorRegistry.cpp

Lines changed: 2 additions & 2 deletions

@@ -114,9 +114,9 @@ std::vector<std::string> getEvaluatorList() {
   return get_evaluator_registry().GetRegisteredEvaluatorList();
 }

-c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args) {
+c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
   auto evaluator = get_evaluator_registry().GetEvaluator(n);
-  return evaluator(n, args);
+  return evaluator(ctx, n, args);
 }

 void register_node_evaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {

core/conversion/evaluators/aten.cpp

Lines changed: 85 additions & 85 deletions
Large diffs are not rendered by default.

core/conversion/evaluators/eval_macros.h

Lines changed: 3 additions & 3 deletions

@@ -5,7 +5,7 @@
 #define DEFINE_GENERIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas)           \
   auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator(         \
       {c10::Symbol::fromQualString(node_kind),                                            \
-       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
+       [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
         if (args.at(n->input(0)).IValue()->isInt()) {                                     \
           auto a = args.at(n->input(0)).unwrapToInt();                                    \
           if (args.at(n->input(1)).IValue()->isInt()) {                                   \
@@ -80,7 +80,7 @@
 #define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas)        \
   auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator(         \
       {c10::Symbol::fromQualString(node_kind),                                            \
-       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
+       [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
         if (args.at(n->input(0)).IValue()->isInt()) {                                     \
           auto a = args.at(n->input(0)).unwrapToInt();                                    \
           if (args.at(n->input(1)).IValue()->isInt()) {                                   \
@@ -127,7 +127,7 @@
 #define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \
   auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator(    \
       {c10::Symbol::fromQualString(node_name),                                            \
-       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
+       [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
         auto a = args.at(n->input(0)).unwrapTo<type>();                                   \
         auto b = args.at(n->input(1)).unwrapTo<type>();                                   \
         return operation;                                                                 \
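
For illustration, a hedged sketch of roughly what DEFINE_TWO_INPUT_SIMPLE_EVALUATOR now expands to; the operator below is illustrative and the schema wiring is abbreviated:

    // Approximate expansion for an integer op: the lambda takes ConversionCtx*
    // first, even when the evaluator never touches the TensorRT network.
    auto floordiv_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator(
        {c10::Symbol::fromQualString("aten::floordiv"),
         [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args)
             -> c10::optional<torch::jit::IValue> {
           auto a = args.at(n->input(0)).unwrapTo<int64_t>();
           auto b = args.at(n->input(1)).unwrapTo<int64_t>();
           return a / b; // the `operation` macro argument
         },
         EvalOptions().validSchemas({"aten::floordiv.int(int a, int b) -> (int)"})});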

core/conversion/evaluators/evaluators.h

Lines changed: 4 additions & 2 deletions

@@ -7,6 +7,8 @@
 #include "torch/csrc/jit/ir/ir.h"

 #include "core/conversion/tensorcontainer/TensorContainer.h"
+#include "core/conversion/conversionctx/ConversionCtx.h"
+#include "core/conversion/converters/converter_util.h"
 #include "core/conversion/var/Var.h"

 namespace torch_tensorrt {
@@ -33,7 +35,7 @@ inline bool constTypesOnly(kwargs& args) {
 // to use the node itself to pull out arguments.
 // This means that you should iterate over node inputs vs. the args
 // when writing evaluators
-typedef std::function<c10::optional<torch::jit::IValue>(const torch::jit::Node*, kwargs&)> NodeEvaluator;
+typedef std::function<c10::optional<torch::jit::IValue>(ConversionCtx*, const torch::jit::Node*, kwargs&)> NodeEvaluator;

 struct EvalOptions {
   std::set<c10::TypePtr> blacklisted_output_types;
@@ -72,7 +74,7 @@ struct EvalRegistration {
     : kind(_kind), evaluator(_evaluator), options(_options){};
 };

-c10::optional<torch::jit::IValue> EvalNode(const torch::jit::Node* n, kwargs& args);
+c10::optional<torch::jit::IValue> EvalNode(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args);
 bool shouldEvalAtConversionTime(const torch::jit::Node* n);
 std::vector<std::string> getEvaluatorList();
 void register_node_evaluator(torch::jit::NodeKind node_kind, NodeEvaluator evaluator);
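
Any out-of-tree evaluator registered against the updated NodeEvaluator typedef must accept the extra parameter as well. A minimal hypothetical registration:

    // Hypothetical: register a pass-through evaluator under the new signature.
    // ctx is unused here but is now part of the NodeEvaluator contract.
    register_node_evaluator(
        c10::Symbol::fromQualString("prim::unchecked_unwrap_optional"),
        [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args)
            -> c10::optional<torch::jit::IValue> {
          (void)ctx;
          return *(args.at(n->input(0)).IValue());
        });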

core/conversion/evaluators/prim.cpp

Lines changed: 22 additions & 16 deletions

@@ -24,28 +24,28 @@ auto prim_registrations =
     RegisterNodeEvaluators()
         .evaluator(
             {torch::jit::prim::Constant,
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (n->output()->type()->kind() == at::FunctionType::Kind) {
                 return {};
               }
               return evaluators::toIValue(n->output());
             }})
         .evaluator(
             {torch::jit::prim::NumToTensor,
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               return evaluators::scalar_to_tensor(args.at(n->input(0)).IValue()->toScalar());
             }})
         .evaluator(
             {torch::jit::prim::ListUnpack,
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
               const torch::jit::IValue* outputs = args.at(n->input()).IValue();
               auto outputVec = outputs->toList().vec();
               return std::move(c10::ivalue::Tuple::create(outputVec));
             }})
         .evaluator(
             {torch::jit::prim::ListConstruct,
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               const auto num_inputs = n->inputs().size();
               if (constTypesOnly(args)) {
                 c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
@@ -103,8 +103,14 @@ auto prim_registrations =
                 if (args.at(in).IValue()->isNone()) {
                   auto ival = torch::jit::IValue();
                   list.emplace_back(std::move(ival));
+                } else if (args.at(in).IValue()->isInt()) {
+                  auto itensor = torch_tensorrt::core::conversion::converters::tensor_to_const(ctx, torch::tensor(args.at(in).unwrapToInt()));
+                  auto tensor_holder = TensorContainer();
+                  tensor_holder.hold_tensor(itensor);
+                  auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+                  list.emplace_back(std::move(ival));
                 } else {
-                  list.emplace_back(std::move(args.at(in).unwrapToTensor()));
+                  list.emplace_back(std::move(args.at(in).unwrapToTensor()));
                 }
               }
             }
@@ -113,7 +119,7 @@ auto prim_registrations =
             }})
         .evaluator(
             {c10::Symbol::fromQualString("prim::dtype"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto input = args.at(n->input(0));
               if (input.isITensor()) {
                 auto trt_dtype = input.ITensor()->getType();
@@ -136,7 +142,7 @@ auto prim_registrations =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("prim::min"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (n->inputs().size() == 1) {
                 auto a = args.at(n->input(0)).unwrapToIntList();
                 int64_t min = std::numeric_limits<int64_t>::max();
@@ -198,7 +204,7 @@ auto prim_registrations =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("prim::max"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               if (n->inputs().size() == 1) {
                 auto a = args.at(n->input(0)).unwrapToIntList();
                 int64_t max = std::numeric_limits<int64_t>::min();
@@ -260,7 +266,7 @@ auto prim_registrations =
             })})
         .evaluator(
             {c10::Symbol::fromQualString("prim::shape"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               LOG_WARNING("There may be undefined behavior using dynamic shape and prim::shape");
               auto tensor_var = args.at(n->input(0));
               if (tensor_var.isITensor()) {
@@ -274,7 +280,7 @@ auto prim_registrations =
             EvalOptions().validSchemas({"prim::shape(Tensor a) -> (int[])"})})
         .evaluator(
             {torch::jit::prim::TupleConstruct,
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               c10::IValue tuple = c10::ivalue::Tuple::create();
               std::vector<c10::IValue> elems;
               for (auto in : n->inputs()) {
@@ -292,7 +298,7 @@ auto prim_registrations =
             }})
         .evaluator(
             {torch::jit::prim::TupleIndex,
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
               auto tuple = args.at(n->input(0)).IValue()->toTuple();
               int64_t idx = args.at(n->input(1)).IValue()->toInt();
@@ -302,24 +308,24 @@ auto prim_registrations =
             EvalOptions().validSchemas({"prim::TupleIndex(Any tup, int i) -> (Any)"})})
         .evaluator(
             {torch::jit::prim::TupleUnpack,
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               // Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
               auto output = args.at(n->input()).IValue()->toTuple();
               return c10::optional<torch::jit::IValue>(std::move(output));
             }})
         .evaluator(
             {c10::Symbol::fromQualString("prim::unchecked_cast"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               return *(args.at(n->input(0)).IValue());
             }})
         .evaluator(
             {c10::Symbol::fromQualString("prim::Uninitialized"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               return c10::IValue::uninitialized();
             }})
         .evaluator(
             {c10::Symbol::fromQualString("prim::RaiseException"),
-             [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+             [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
               auto exception = args.at(n->input(0)).IValue();
               TORCHTRT_THROW_ERROR("Error from TorchScript: " << *exception);
               return {};
@@ -328,4 +334,4 @@ auto prim_registrations =
 } // namespace evaluators
 } // namespace conversion
 } // namespace core
-} // namespace torch_tensorrt
+} // namespace torch_tensorrt
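
The prim::ListConstruct branch above is where the dual behavior lands: an int element known at conversion time is promoted to a TensorRT ITensor through a constant layer, then carried through the generic IValue list inside a TensorContainer. A hedged sketch of how a downstream consumer might tell the two element kinds apart (hold_tensor appears in the diff; the tensor() accessor is assumed from the existing TensorContainer pattern):

    // Hypothetical consumer: distinguish TensorContainer-wrapped ITensors from
    // ordinary at::Tensors in a list built by the prim::ListConstruct evaluator.
    void inspect_list_element(const c10::IValue& elem) {
      if (elem.isCustomClass()) {
        auto container = elem.toCustomClass<TensorContainer>();
        nvinfer1::ITensor* trt_tensor = container->tensor(); // assumed accessor
        (void)trt_tensor; // produced by tensor_to_const(ctx, torch::tensor(<int>))
      } else if (elem.isTensor()) {
        at::Tensor torch_tensor = elem.toTensor(); // regular compile-time tensor
        (void)torch_tensor;
      }
    }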
