
Commit 40dfa81 (parent: ba569aa)

Merge pull request #1851 from pytorch/exp_size

feat: Wrap dynamic size handling in a compilation flag

File tree: 15 files changed (+90 −17 lines)

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 1 addition & 0 deletions

@@ -23,6 +23,7 @@ struct BuilderSettings {
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
+  bool allow_shape_tensors = false;
   ir::Device device;
   nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
   nvinfer1::IInt8Calibrator* calibrator = nullptr;
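The new allow_shape_tensors field defaults to false, so the IShapeLayer-based handling of aten::size stays opt-in and conversion behavior is unchanged unless the flag is explicitly enabled. Evaluators consult it through ctx->settings.allow_shape_tensors, as the aten.cpp hunks below show.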

core/conversion/evaluators/aten.cpp

Lines changed: 14 additions & 3 deletions

@@ -270,7 +270,12 @@ auto aten_registrations TORCHTRT_UNUSED =
           if (tensor_var.isITensor()) {
             auto tensor = tensor_var.ITensor();
             if (ctx->input_is_dynamic) {
-              return dynamic_size_layer(ctx, n, args);
+              if (ctx->settings.allow_shape_tensors) {
+                return dynamic_size_layer(ctx, n, args);
+              } else {
+                LOG_WARNING(
+                    "There may be undefined behavior using dynamic shape and aten::size without setting allow_shape_tensors");
+              }
             }
             return util::toVec(tensor->getDimensions());
           } else if (tensor_var.IValue()->isTensor()) {
@@ -286,7 +291,12 @@ auto aten_registrations TORCHTRT_UNUSED =
           auto dim = args.at(n->input(1)).unwrapToInt();
           if (tensor_var.isITensor()) {
             if (ctx->input_is_dynamic) {
-              return dynamic_size_layer(ctx, n, args);
+              if (ctx->settings.allow_shape_tensors) {
+                return dynamic_size_layer(ctx, n, args);
+              } else {
+                LOG_WARNING(
+                    "There may be undefined behavior using dynamic shape and aten::size without setting allow_shape_tensors");
+              }
             }
             auto tensor = tensor_var.ITensor();
             auto dims = util::toVec(tensor->getDimensions());
@@ -605,7 +615,8 @@ auto aten_registrations TORCHTRT_UNUSED =
       .evaluator(
           {c10::Symbol::fromQualString("aten::numel"),
            [](ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
-             LOG_WARNING("There may be undefined behavior using dynamic shape and aten::numel");
+             LOG_WARNING(
+                 "There may be undefined behavior using dynamic shape and aten::numel without setting allow_shape_tensors");
              auto tensor_var = args.at(n->input(0));
              if (tensor_var.isITensor()) {
                auto tensor = tensor_var.ITensor();
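Both aten::size hunks apply the same gate: with dynamic input shapes, dynamic_size_layer is used only when the new flag is set; otherwise the evaluator warns and falls through to the static util::toVec(tensor->getDimensions()) result. Below is a minimal standalone sketch of that decision (plain C++ with illustrative stand-in types, not the library's actual ConversionCtx):

// Minimal standalone model (plain C++, illustrative names, no TensorRT)
// of the decision the patched aten::size evaluator makes.
#include <iostream>

struct Settings {
  bool allow_shape_tensors = false; // the new BuilderSettings field
};
struct Ctx {
  bool input_is_dynamic = false;
  Settings settings;
};

// True when the evaluator would take the IShapeLayer (dynamic_size_layer)
// path; false means it falls through to the static dimensions, which may
// still contain -1 placeholders for dynamic axes.
bool takes_shape_tensor_path(const Ctx& ctx) {
  if (ctx.input_is_dynamic) {
    if (ctx.settings.allow_shape_tensors) {
      return true;
    }
    std::cerr << "WARNING: There may be undefined behavior using dynamic "
                 "shape and aten::size without setting allow_shape_tensors\n";
  }
  return false;
}

int main() {
  Ctx ctx;
  ctx.input_is_dynamic = true;
  std::cout << takes_shape_tensor_path(ctx) << "\n"; // 0: warns, static path
  ctx.settings.allow_shape_tensors = true;
  std::cout << takes_shape_tensor_path(ctx) << "\n"; // 1: shape-tensor path
}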

core/conversion/evaluators/eval_util.cpp

Lines changed: 25 additions & 7 deletions

@@ -32,7 +32,9 @@ nvinfer1::ITensor* index_layer(
 c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
   LOG_DEBUG("Using dynamic version of aten::size evaluator");
   auto in = args.at(n->input(0)).ITensorOrFreeze(ctx);
-  LOG_DEBUG("Input dimensions: " << in->getDimensions());
+  auto input_dims = in->getDimensions();
+  LOG_DEBUG("Input dimensions: " << input_dims);
+
   auto shape_layer = ctx->net->addShape(*in);
   TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
   auto shape_1d_tensor = shape_layer->getOutput(0);
@@ -44,15 +46,31 @@ c10::IValue dynamic_size_layer(ConversionCtx* ctx, const torch::jit::Node* n, kwargs& args) {
     dim = dim < 0 ? dim + maxDim : dim;
     LOG_DEBUG("Dimension to select: " << dim);
     shape_1d_tensor = index_layer(ctx, n, shape_1d_tensor, dim);
-  }
+    LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());

-  LOG_DEBUG("Output tensor shape: " << shape_1d_tensor->getDimensions());
+    auto tensor_holder = TensorContainer();
+    tensor_holder.hold_tensor(shape_1d_tensor);
+    auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));

-  auto tensor_holder = TensorContainer();
-  tensor_holder.hold_tensor(shape_1d_tensor);
-  auto shape_1d_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
+    return shape_1d_ivalue;

-  return shape_1d_ivalue;
+  } else {
+    auto input_size = c10::impl::GenericList(c10::AnyType::get());
+    // Only express the dynamic dimension with a shape layer output.
+    // The static dimensions are preserved in the input size.
+    for (int32_t i = 0; i < input_dims.nbDims; i++) {
+      if (input_dims.d[i] == -1) {
+        auto dynamic_dim_tensor = index_layer(ctx, n, shape_1d_tensor, i);
+        auto dynamic_dim_holder = TensorContainer();
+        dynamic_dim_holder.hold_tensor(dynamic_dim_tensor);
+        auto dynamic_dim_ivalue = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(dynamic_dim_holder)));
+        input_size.emplace_back(std::move(dynamic_dim_ivalue));
+      } else {
+        input_size.emplace_back(input_dims.d[i]);
+      }
+    }
+    return c10::IValue(input_size);
+  }
 }

 int64_t normalizeIndex(int64_t idx, int64_t list_size) {
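The reworked dynamic_size_layer now covers two cases: aten::size(t, dim) still returns a single shape-tensor IValue, while aten::size(t) with no dim argument builds a GenericList in which only the dimensions reported as -1 become shape-tensor handles and every static extent stays a plain integer. A self-contained sketch of that list-building rule (plain C++17 with a std::variant stand-in for the IValue, no TensorRT dependency):

// Self-contained sketch (plain C++17, no TensorRT) of the new full-size
// path: only dimensions reported as -1 become shape-tensor handles; static
// extents are preserved as plain integers in the returned list.
#include <cstdint>
#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Stand-in for an IValue that is either a static extent or a handle to an
// IShapeLayer gather output (the real code wraps an ITensor* in a
// TensorContainer and stores it in a c10::impl::GenericList).
using SizeEntry = std::variant<int64_t, std::string>;

std::vector<SizeEntry> build_size_list(const std::vector<int64_t>& dims) {
  std::vector<SizeEntry> size_list;
  for (size_t i = 0; i < dims.size(); i++) {
    if (dims[i] == -1) {
      // Dynamic axis: the real code calls index_layer() to gather element i
      // of the shape tensor produced by ctx->net->addShape(*in).
      size_list.emplace_back("shape_tensor[" + std::to_string(i) + "]");
    } else {
      size_list.emplace_back(dims[i]); // static extent, kept as-is
    }
  }
  return size_list;
}

int main() {
  for (const auto& entry : build_size_list({-1, 3, 224, 224})) {
    std::visit([](const auto& v) { std::cout << v << " "; }, entry);
  }
  std::cout << "\n"; // prints: shape_tensor[0] 3 224 224
}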

cpp/bin/torchtrtc/main.cpp

Lines changed: 10 additions & 0 deletions

@@ -168,6 +168,12 @@ int main(int argc, char** argv) {
       "Truncate weights that are provided in 64bit to 32bit (Long, Double to Int, Float)",
       {"truncate", "truncate-long-double", "truncate-64bit"});

+  args::Flag allow_shape_tensors(
+      parser,
+      "allow-shape-tensors",
+      "(Experimental) Allow aten::size to output shape tensors using IShapeLayer in TensorRT",
+      {"allow-shape-tensors"});
+
   args::Flag save_engine(
       parser,
       "save_engine",
@@ -443,6 +449,10 @@ int main(int argc, char** argv) {
     compile_settings.truncate_long_and_double = true;
   }

+  if (allow_shape_tensors) {
+    compile_settings.allow_shape_tensors = true;
+  }
+
   torch::jit::Module mod;
   try {
     // Deserialize the ScriptModule from a file using torch::jit::load().
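With the flag both registered on the parser and copied into compile_settings, an opt-in run would look something like: torchtrtc model.ts trt.ts "(1,3,224,224)" --allow-shape-tensors. This invocation is illustrative only; the positional input-spec syntax follows torchtrtc's existing usage and is untouched by this commit.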

cpp/include/torch_tensorrt/torch_tensorrt.h

Lines changed: 5 additions & 0 deletions

@@ -791,6 +791,11 @@ struct CompileSpec {
    */
   bool truncate_long_and_double = false;

+  /**
+   * Allow shape tensors (from IShape layer) in the graph
+   */
+  bool allow_shape_tensors = false;
+
   /**
    * Target Device
    */
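For C++ callers the new member sits directly on the public CompileSpec. A hedged usage sketch, assuming the existing torch_tensorrt::ts frontend API (only the allow_shape_tensors line is new in this commit):

// Hedged sketch: opting into shape tensors from the C++ API. Assumes the
// existing torch_tensorrt::torchscript (ts) frontend; everything except the
// allow_shape_tensors assignment predates this commit.
#include <torch/script.h>
#include <torch_tensorrt/torch_tensorrt.h>

torch::jit::Module compile_with_shape_tensors(torch::jit::Module& mod) {
  // Dynamic batch dimension via min/opt/max shapes, so aten::size can
  // actually encounter a dynamic (-1) extent at conversion time.
  auto input = torch_tensorrt::Input(
      /*min_shape=*/{1, 3, 224, 224},
      /*opt_shape=*/{8, 3, 224, 224},
      /*max_shape=*/{16, 3, 224, 224});
  torch_tensorrt::ts::CompileSpec spec({input});
  spec.allow_shape_tensors = true; // experimental: emit IShapeLayer for aten::size
  return torch_tensorrt::ts::compile(mod, spec);
}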

cpp/src/compile_spec.cpp

Lines changed: 1 addition & 0 deletions

@@ -90,6 +90,7 @@ torchtrt::core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   internal.convert_info.engine_settings.refit = external.refit;
   internal.convert_info.engine_settings.debug = external.debug;
   internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
+  internal.convert_info.engine_settings.allow_shape_tensors = external.allow_shape_tensors;
   internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
   internal.lower_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;
   internal.partitioning_info.target_device.allow_gpu_fallback = external.device.allow_gpu_fallback;

py/torch_tensorrt/csrc/register_tensorrt_classes.cpp

Lines changed: 1 addition & 0 deletions

@@ -84,6 +84,7 @@ void RegisterTRTCompileSpec() {
       TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, dla_global_dram_size);
   ADD_FIELD_GET_SET_REGISTRATION(
       TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, truncate_long_and_double);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, torch_tensorrt::pyapi::CompileSpec, allow_shape_tensors);
 }

 struct TRTTSRegistrations {

py/torch_tensorrt/csrc/tensorrt_classes.cpp

Lines changed: 2 additions & 0 deletions

@@ -373,6 +373,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
   info.partitioning_info.truncate_long_and_double = truncate_long_and_double;
   info.lower_info.forced_fallback_modules = torch_fallback.forced_fallback_modules;
   info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;
+  info.convert_info.engine_settings.allow_shape_tensors = allow_shape_tensors;

   info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
   TORCHTRT_CHECK(num_avg_timing_iters >= 0, "num_avg_timing_iters must be 0 or greater");
@@ -423,6 +424,7 @@ std::string CompileSpec::stringify() {
   ss << "    \"DLA Local DRAM Size\": " << dla_local_dram_size << std::endl;
   ss << "    \"DLA Global DRAM Size\": " << dla_global_dram_size << std::endl;
   ss << "    \"Truncate long and double\": " << truncate_long_and_double << std::endl;
+  ss << "    \"Allow Shape tensors\": " << allow_shape_tensors << std::endl;
   ss << "    \"Torch Fallback\": " << torch_fallback.to_str();
   ss << "}";
   return ss.str();

py/torch_tensorrt/csrc/tensorrt_classes.h

Lines changed: 2 additions & 0 deletions

@@ -167,6 +167,7 @@ struct CompileSpec : torch::CustomClassHolder {
   ADD_FIELD_GET_SET(dla_local_dram_size, int64_t);
   ADD_FIELD_GET_SET(dla_global_dram_size, int64_t);
   ADD_FIELD_GET_SET(truncate_long_and_double, bool);
+  ADD_FIELD_GET_SET(allow_shape_tensors, bool);
   ADD_FIELD_GET_SET(device, Device);
   ADD_FIELD_GET_SET(torch_fallback, TorchFallback);
   ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);
@@ -180,6 +181,7 @@ struct CompileSpec : torch::CustomClassHolder {
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
+  bool allow_shape_tensors = false;
   Device device;
   TorchFallback torch_fallback;
   EngineCapability capability = EngineCapability::kDEFAULT;

py/torch_tensorrt/csrc/torch_tensorrt_py.cpp

Lines changed: 2 additions & 1 deletion

@@ -371,7 +371,8 @@ PYBIND11_MODULE(_C, m) {
       .def_readwrite("dla_local_dram_size", &CompileSpec::dla_local_dram_size)
       .def_readwrite("dla_global_dram_size", &CompileSpec::dla_global_dram_size)
       .def_readwrite("torch_fallback", &CompileSpec::torch_fallback)
-      .def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double);
+      .def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double)
+      .def_readwrite("allow_shape_tensors", &CompileSpec::allow_shape_tensors);

   py::class_<TorchFallback>(ts_sub_mod, "TorchFallback")
       .def(py::init<>())
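On the Python side this exposes allow_shape_tensors as a read/write attribute of the internal _C.CompileSpec, mirroring the TorchScript serialization field registered above. The plumbing from the public torch_tensorrt.compile Python signature into this attribute lives in the Python frontend, outside the hunks shown here.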
