Skip to content

Commit 8156465

Browse files
authored
Merge pull request #407 from inocsin/double_long_ival
feat: support truncate long/double to int/float with option
2 parents b333543 + 5c1bf0c commit 8156465

File tree

9 files changed

+35
-5
lines changed

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ struct BuilderSettings {
2828
bool refit = false;
2929
bool debug = false;
3030
bool strict_types = false;
31+
bool truncate_long_and_double = false;
3132
Device device;
3233
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
3334
nvinfer1::IInt8Calibrator* calibrator = nullptr;

core/conversion/var/Var.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
8989
if (isIValue()) {
9090
LOG_DEBUG(ctx->logger, "Found IValue containing object of type " << *(ptr_.ivalue->type()));
9191
}
92+
9293
TRTORCH_CHECK(
9394
isITensor() || (isIValue() && (ptr_.ivalue->isTensor() || ptr_.ivalue->isCustomClass())),
9495
"Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name());
@@ -97,11 +98,22 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
9798

9899
if (isIValue()) {
99100
if (ptr_.ivalue->isTensor()) {
100-
auto weights = converters::Weights(ctx, ptr_.ivalue->toTensor());
101-
101+
auto weights = converters::Weights();
102+
auto tensor = ptr_.ivalue->toTensor();
103+
if ((tensor.scalar_type() == at::kLong || tensor.scalar_type() == at::kDouble) && !ctx->settings.truncate_long_and_double) {
104+
TRTORCH_THROW_ERROR("Unable to freeze tensor of type Int64/Float64 into constant layer, try to compile model with truncate_long_and_double enabled");
105+
} else if (tensor.scalar_type() == at::kLong && ctx->settings.truncate_long_and_double) {
106+
weights = converters::Weights(ctx, tensor.toType(at::kInt));
107+
LOG_WARNING("Truncating weight (constant in the graph) from Int64 to Int32");
108+
} else if (tensor.scalar_type() == at::kDouble && ctx->settings.truncate_long_and_double) {
109+
weights = converters::Weights(ctx, tensor.toType(at::kFloat));
110+
LOG_WARNING("Truncating weight (constant in the graph) from Float64 to Float32");
111+
} else {
112+
weights = converters::Weights(ctx, tensor);
113+
}
114+
102115
auto const_layer = ctx->net->addConstant(weights.shape, weights.data);
103116
TRTORCH_CHECK(const_layer, "Unable to freeze tensor into constant layer");
104-
105117
out = const_layer->getOutput(0);
106118

107119
std::ostringstream tensor_id;
@@ -119,7 +131,6 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
119131
}
120132

121133
LOG_DEBUG("Frozen tensor shape: " << out->getDimensions());
122-
123134
return out;
124135
}
125136

cpp/api/include/trtorch/trtorch.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,11 @@ struct TRTORCH_API CompileSpec {
258258
*/
259259
bool debug = false;
260260

261+
/**
262+
* Truncate long/double type to int/float type
263+
*/
264+
bool truncate_long_and_double = false;
265+
261266
/**
262267
* Restrict operating type to only set default operation precision
263268
* (op_precision)

cpp/api/src/compile_spec.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
9292
internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;
9393
internal.convert_info.engine_settings.refit = external.refit;
9494
internal.convert_info.engine_settings.debug = external.debug;
95+
internal.convert_info.engine_settings.truncate_long_and_double = external.truncate_long_and_double;
9596
internal.convert_info.engine_settings.strict_types = external.strict_types;
9697
internal.convert_info.engine_settings.device.allow_gpu_fallback = external.device.allow_gpu_fallback;
9798
internal.convert_info.engine_settings.max_batch_size = external.max_batch_size;

py/trtorch/_compile_spec.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
176176
if "max_batch_size" in compile_spec:
177177
assert type(compile_spec["max_batch_size"]) is int
178178
info.max_batch_size = compile_spec["max_batch_size"]
179+
180+
if "truncate_long_and_double" in compile_spec:
181+
assert type(compile_spec["truncate_long_and_double"]) is bool
182+
info.truncate_long_and_double = compile_spec["truncate_long_and_double"]
179183

180184
return info
181185

@@ -217,6 +221,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
217221
"num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
218222
"workspace_size": 0, # Maximum size of workspace given to TensorRT
219223
"max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
224+
"truncate_long_and_double": False, # Truncate long and double into int and float
220225
})
221226
}
222227

@@ -257,6 +262,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
257262
backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
258263
backend_spec.set_workspace_size(parsed_spec.workspace_size)
259264
backend_spec.set_max_batch_size(parsed_spec.max_batch_size)
265+
backend_spec.set_truncate_long_and_double(parsed_spec.truncate_long_and_double)
260266
backend_spec._set_ptq_calibrator(parsed_spec._get_calibrator_handle())
261267

262268
return backend_spec

py/trtorch/csrc/register_tensorrt_classes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ void RegisterTRTCompileSpec() {
4242
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, num_avg_timing_iters);
4343
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, workspace_size);
4444
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, max_batch_size);
45+
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, truncate_long_and_double);
4546
}
4647

4748
struct TRTTSRegistrations {

py/trtorch/csrc/tensorrt_classes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
108108
info.convert_info.engine_settings.device.gpu_id = device.gpu_id;
109109
info.convert_info.engine_settings.device.dla_core = device.dla_core;
110110
info.convert_info.engine_settings.device.allow_gpu_fallback = device.allow_gpu_fallback;
111+
info.convert_info.engine_settings.truncate_long_and_double = truncate_long_and_double;
111112

112113
info.convert_info.engine_settings.capability = toTRTEngineCapability(capability);
113114
TRTORCH_CHECK(num_min_timing_iters >= 0, "num_min_timing_iters must be 0 or greater");
@@ -143,6 +144,7 @@ std::string CompileSpec::stringify() {
143144
ss << " \"Num Avg Timing Iters\": " << num_avg_timing_iters << std::endl;
144145
ss << " \"Workspace Size\": " << workspace_size << std::endl;
145146
ss << " \"Max Batch Size\": " << max_batch_size << std::endl;
147+
ss << " \"Truncate long and double\": " << truncate_long_and_double << std::endl;
146148
ss << "}";
147149
return ss.str();
148150
}

py/trtorch/csrc/tensorrt_classes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ struct CompileSpec : torch::CustomClassHolder {
115115
ADD_FIELD_GET_SET(num_min_timing_iters, int64_t);
116116
ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
117117
ADD_FIELD_GET_SET(workspace_size, int64_t);
118+
ADD_FIELD_GET_SET(truncate_long_and_double, bool);
118119
ADD_FIELD_GET_SET(max_batch_size, int64_t);
119120
ADD_FIELD_GET_SET(device, Device);
120121
ADD_FIELD_GET_SET(ptq_calibrator, nvinfer1::IInt8Calibrator*);
@@ -126,6 +127,7 @@ struct CompileSpec : torch::CustomClassHolder {
126127
bool refit = false;
127128
bool debug = false;
128129
bool strict_types = false;
130+
bool truncate_long_and_double = false;
129131
Device device;
130132
EngineCapability capability = EngineCapability::kDEFAULT;
131133
int64_t num_min_timing_iters = 2;

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ PYBIND11_MODULE(_C, m) {
246246
.def_readwrite("num_min_timing_iters", &CompileSpec::num_min_timing_iters)
247247
.def_readwrite("num_avg_timing_iters", &CompileSpec::num_avg_timing_iters)
248248
.def_readwrite("workspace_size", &CompileSpec::workspace_size)
249-
.def_readwrite("max_batch_size", &CompileSpec::max_batch_size);
249+
.def_readwrite("max_batch_size", &CompileSpec::max_batch_size)
250+
.def_readwrite("truncate_long_and_double", &CompileSpec::truncate_long_and_double);
250251

251252
py::class_<Device>(m, "Device")
252253
.def(py::init<>())

0 commit comments

Comments (0)