@@ -48,18 +48,26 @@ int AutocastLongInputs(
48
48
auto dtype = dtype_input->second .value ();
49
49
// Currently, we do not autocast inputs for which the determined type is not long
50
50
if (dtype != at::kLong ) {
51
+ LOG_DEBUG (
52
+ " Skipping autocast for tensor " << input->debugName () << " , since its dtype is " << dtype
53
+ << " and not at::kLong" );
51
54
continue ;
52
55
}
53
56
54
57
LOG_DEBUG (" Inserting aten::to casting " << input->debugName () << " to dtype " << dtype);
55
58
56
59
// Generate cast node sending input tensors to the inferred or specified datatype (long)
60
+ torch::jit::Value *const_false, *cuda, *none_val;
61
+ if (num_autocasts == 0 ) {
62
+ // Only generate constants once and reuse for all autocasts
63
+ const_false = g->insertConstant (0 );
64
+ const_false->setType (torch::jit::BoolType::get ());
65
+ cuda = g->insertConstant (target_device_name);
66
+ cuda->setType (torch::jit::DeviceObjType::get ());
67
+ none_val = g->insertNode (g->createNone ())->output ();
68
+ }
69
+
57
70
auto const_type = g->insertConstant (dtype);
58
- auto const_false = g->insertConstant (0 );
59
- const_false->setType (torch::jit::BoolType::get ());
60
- auto cuda = g->insertConstant (target_device_name);
61
- cuda->setType (torch::jit::DeviceObjType::get ());
62
- auto none_val = g->insertNode (g->createNone ())->output ();
63
71
auto cast_node = g->create (torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val});
64
72
65
73
// Replace all uses of the original tensor with that of the casted tensor
@@ -73,12 +81,16 @@ int AutocastLongInputs(
73
81
}
74
82
}
75
83
76
- LOG_WARNING (
77
- " Input tensors to this Torch-TRT engine may have their data types in-place modified "
78
- << " if the type does not match the determined required type for TRT. To disable this "
79
- << " automatic casting, specify an Input dtype other than Long" );
84
+ LOG_GRAPH (" Inserted " << num_autocasts << " autocasts" );
80
85
81
- LOG_GRAPH (" Graph after Autocast: " << *g);
86
+ if (num_autocasts > 0 ) {
87
+ LOG_WARNING (
88
+ " Data types for input tensors have been modified by inserting "
89
+ << " aten::to operations which cast INT64 inputs to INT32. "
90
+ << " To disable this, please recompile using INT32 inputs" );
91
+
92
+ LOG_GRAPH (" Graph after Autocast: " << *g);
93
+ }
82
94
83
95
return num_autocasts;
84
96
}
0 commit comments