Skip to content

Commit a9ab1b4

Browse files
committed
refactor(//py): Redo Python tests for new API
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 483ef59 commit a9ab1b4

21 files changed

+252
-368
lines changed

core/compiler.cpp

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -297,46 +297,48 @@ void MapInputsAndDetermineDTypes(
297297
cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
298298

299299
for (auto& in : g->inputs()) {
300-
auto est_type_opt = first_use_type_map.find(in)->second;
301-
ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
302-
if (est_type_opt && !spec.dtype_is_user_defined) {
303-
// If we can calculate the type from the graph and the type was not defined by the user then use the calculated
304-
// type
305-
LOG_INFO(
306-
"Since input type is not explicitly defined, infering using first tensor calculation\n Found input "
307-
<< in->debugName() << " has type " << est_type_opt.value()
308-
<< ". If this is incorrect explicitly set dtype for input and file a bug");
309-
spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
310-
} else if (!est_type_opt && !spec.dtype_is_user_defined) {
311-
// If we cannot calculate the type and the user did not define the type, then default to FP32
312-
LOG_WARNING(
313-
"Cannot infer input type from calcuations in graph for input "
314-
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
315-
spec.dtype = nvinfer1::DataType::kFLOAT;
316-
} else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
317-
if (!est_type_opt) {
318-
LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
319-
} else {
320-
if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
321-
std::stringstream ss;
322-
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
323-
ss << cfg.convert_info.inputs.find(in)->second.dtype;
324-
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
325-
ss << est_type_opt.value() << std::endl;
326-
ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
327-
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
328-
ss << "compatibility with PyTorch's data type convention is required.\n";
329-
ss << "If you do indeed see errors at runtime either:\n";
330-
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
331-
ss << "- Disable partial compilation by setting require_full_compilation to True";
332-
auto warn_str = ss.str();
333-
LOG_WARNING(warn_str);
334-
// Overwrite type map with user settings
335-
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
300+
if (static_params.find(in) == static_params.end()) {
301+
ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
302+
auto est_type_opt = first_use_type_map.find(in)->second;
303+
if (est_type_opt && !spec.dtype_is_user_defined) {
304+
// If we can calculate the type from the graph and the type was not defined by the user then use the calculated
305+
// type
306+
LOG_INFO(
307+
"Since input type is not explicitly defined, infering using first tensor calculation\n Found input "
308+
<< in->debugName() << " has type " << est_type_opt.value()
309+
<< ". If this is incorrect explicitly set dtype for input and file a bug");
310+
spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
311+
} else if (!est_type_opt && !spec.dtype_is_user_defined) {
312+
// If we cannot calculate the type and the user did not define the type, then default to FP32
313+
LOG_WARNING(
314+
"Cannot infer input type from calcuations in graph for input "
315+
<< in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
316+
spec.dtype = nvinfer1::DataType::kFLOAT;
317+
} else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
318+
if (!est_type_opt) {
319+
LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
320+
} else {
321+
if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
322+
std::stringstream ss;
323+
ss << "For input " << in->debugName() << ", found user specified input dtype as ";
324+
ss << cfg.convert_info.inputs.find(in)->second.dtype;
325+
ss << ", however when inspecting the graph, the input type expected was inferred to be ";
326+
ss << est_type_opt.value() << std::endl;
327+
ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
328+
ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
329+
ss << "compatibility with PyTorch's data type convention is required.\n";
330+
ss << "If you do indeed see errors at runtime either:\n";
331+
ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
332+
ss << "- Disable partial compilation by setting require_full_compilation to True";
333+
auto warn_str = ss.str();
334+
LOG_WARNING(warn_str);
335+
// Overwrite type map with user settings
336+
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
337+
}
336338
}
339+
} else {
340+
// The user defined the type so no changes are necessary
337341
}
338-
} else {
339-
// The user defined the type so no changes are necessary
340342
}
341343
}
342344
}

core/util/jit_util.cpp

Lines changed: 0 additions & 113 deletions
This file was deleted.

py/torch_tensorrt/_Device.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,19 @@ class Device(object):
1212
Defines a device that can be used to specify target devices for engines
1313
1414
Attributes:
15-
device_type (trtorch.DeviceType): Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
15+
device_type (torch_tensorrt.DeviceType): Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
1616
gpu_id (int): Device ID for target GPU
1717
dla_core (int): Core ID for target DLA core
1818
allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
1919
"""
2020

21-
device_type = None #: (trtorch.DeviceType): Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
21+
device_type = None #: (torch_tensorrt.DeviceType): Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
2222
gpu_id = -1 #: (int) Device ID for target GPU
2323
dla_core = -1 #: (int) Core ID for target DLA core
2424
allow_gpu_fallback = False #: (bool) Whether falling back to GPU if DLA cannot support an op should be allowed
2525

2626
def __init__(self, *args, **kwargs):
27-
""" __init__ Method for trtorch.Device
27+
""" __init__ Method for torch_tensorrt.Device
2828
2929
Device accepts one of a few construction patterns
3030

py/torch_tensorrt/_Input.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,30 @@ class Input(object):
1212
Defines an input to a module in terms of expected shape, data type and tensor format.
1313
1414
Attributes:
15-
shape_mode (trtorch.Input._ShapeMode): Is input statically or dynamically shaped
15+
shape_mode (torch_tensorrt.Input._ShapeMode): Is input statically or dynamically shaped
1616
shape (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape.
1717
Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form
1818
``{
1919
"min_shape": Tuple,
2020
"opt_shape": Tuple,
2121
"max_shape": Tuple
2222
}``
23-
dtype (trtorch.dtype): The expected data type of the input tensor (default: trtorch.dtype.float32)
24-
format (trtorch.TensorFormat): The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)
23+
dtype (torch_tensorrt.dtype): The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
24+
format (torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
2525
"""
2626

2727
class _ShapeMode(Enum):
2828
STATIC = 0
2929
DYNAMIC = 1
3030

31-
shape_mode = None #: (trtorch.Input._ShapeMode): Is input statically or dynamically shaped
31+
shape_mode = None #: (torch_tensorrt.Input._ShapeMode): Is input statically or dynamically shaped
3232
shape = None #: (Tuple or Dict): Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
33-
dtype = _enums.dtype.unknown #: The expected data type of the input tensor (default: trtorch.dtype.float32)
33+
dtype = _enums.dtype.unknown #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
3434
_explicit_set_dtype = False
35-
format = _enums.TensorFormat.contiguous #: The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)
35+
format = _enums.TensorFormat.contiguous #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
3636

3737
def __init__(self, *args, **kwargs):
38-
""" __init__ Method for trtorch.Input
38+
""" __init__ Method for torch_tensorrt.Input
3939
4040
Input accepts one of a few construction patterns
4141
@@ -50,13 +50,13 @@ def __init__(self, *args, **kwargs):
5050
Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implictly this sets Input's shape_mode to DYNAMIC
5151
max_shape (Tuple or List, optional): Max size of input tensor's shape range
5252
Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implictly this sets Input's shape_mode to DYNAMIC
53-
dtype (torch.dtype or trtorch.dtype): Expected data type for input tensor (default: trtorch.dtype.float32)
54-
format (torch.memory_format or trtorch.TensorFormat): The expected format of the input tensor (default: trtorch.TensorFormat.NCHW)
53+
dtype (torch.dtype or torch_tensorrt.dtype): Expected data type for input tensor (default: torch_tensorrt.dtype.float32)
54+
format (torch.memory_format or torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
5555
5656
Examples:
5757
- Input([1,3,32,32], dtype=torch.float32, format=torch.channel_last)
58-
- Input(shape=(1,3,32,32), dtype=trtorch.dtype.int32, format=trtorch.TensorFormat.NCHW)
59-
- Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=trtorch.dtype.float32, format=trtorch.TensorFormat.NCHW
58+
- Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)
59+
- Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW
6060
"""
6161
if len(args) == 1:
6262
if not Input._supported_input_size_type(args[0]):
@@ -204,7 +204,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
204204
return dtype
205205

206206
else:
207-
raise TypeError("Input data type needs to be specified with a torch.dtype or a trtorch.dtype, got: " +
207+
raise TypeError("Input data type needs to be specified with a torch.dtype or a torch_tensorrt.dtype, got: " +
208208
str(type(dtype)))
209209

210210
@staticmethod
@@ -223,7 +223,7 @@ def _parse_format(format: Any) -> _enums.TensorFormat:
223223

224224
else:
225225
raise TypeError(
226-
"Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")
226+
"Tensor format needs to be specified with either torch.memory_format or torch_tensorrt.TensorFormat")
227227

228228
@classmethod
229229
def _from_tensor(cls, t: torch.Tensor):

py/torch_tensorrt/_util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,19 @@
22
from torch_tensorrt import _C
33

44
def dump_build_info():
5-
"""Prints build information about the TRTorch distribution to stdout
5+
"""Prints build information about the torch_tensorrt distribution to stdout
66
"""
77
print(get_build_info())
88

99

1010
def get_build_info() -> str:
11-
"""Returns a string containing the build information of TRTorch distribution
11+
"""Returns a string containing the build information of torch_tensorrt distribution
1212
1313
Returns:
14-
str: String containing the build information for TRTorch distribution
14+
str: String containing the build information for torch_tensorrt distribution
1515
"""
1616
build_info = _C.get_build_info()
17-
build_info = "TRTorch Version: " + str(__version__) + '\n' + build_info
17+
build_info = "Torch-TensorRT Version: " + str(__version__) + '\n' + build_info
1818
return build_info
1919

2020

py/torch_tensorrt/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "1.0.0a0+808e603f"
1+
__version__ = "1.0.0a0+483ef591"

py/torch_tensorrt/csrc/register_tensorrt_classes.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ void RegisterTRTCompileSpec() {
1919
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, dtype);
2020
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, format);
2121
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, input_is_dynamic);
22+
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::Input, explicit_set_dtype);
23+
2224

2325
static auto TRTORCH_UNUSED TRTDeviceTSRegistration = torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
2426
.def(torch::init<>())

py/torch_tensorrt/csrc/tensorrt_backend.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace backend {
1212

1313
c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::GenericDict method_compile_spec) {
1414
auto mod = mod_val.toModule();
15+
1516
auto spec = c10::impl::toTypedDict<std::string, at::IValue>(method_compile_spec);
1617

1718
auto handles = c10::impl::GenericDict(
@@ -22,17 +23,12 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::
2223
const auto& method_name = it->key();
2324
auto raw_spec = it->value().toCustomClass<trtorch::pyapi::CompileSpec>();
2425
LOG_DEBUG(raw_spec->stringify());
25-
auto cfg = raw_spec->toInternalCompileSpec();
26-
auto graph_and_ivals = core::lowering::Lower(mod_, method_name, cfg.lower_info);
27-
28-
auto g = graph_and_ivals.first;
29-
auto params = graph_and_ivals.second;
30-
auto named_params = core::ir::get_static_params(g->inputs(), params);
3126

27+
auto cfg = raw_spec->toInternalCompileSpec();
3228
auto convert_cfg = std::move(cfg.convert_info);
3329
auto device_spec = convert_cfg.engine_settings.device;
3430
auto device = core::runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
35-
auto serialized_engine = core::conversion::ConvertBlockToEngine(g->block(), convert_cfg, named_params);
31+
auto serialized_engine = core::ConvertGraphToTRTEngine(mod_, method_name, cfg);
3632
auto engine_handle = c10::make_intrusive<core::runtime::TRTEngine>(it->key(), serialized_engine, device);
3733
handles.insert(method_name, at::IValue(engine_handle));
3834
}

py/torch_tensorrt/csrc/tensorrt_classes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ struct Input : torch::CustomClassHolder {
5050
ADD_FIELD_GET_SET(max, std::vector<int64_t>);
5151
ADD_FIELD_GET_SET(input_is_dynamic, bool);
5252
ADD_FIELD_GET_SET(explicit_set_dtype, bool);
53-
ADD_ENUM_GET_SET(dtype, DataType, static_cast<int64_t>(DataType::kBool));
53+
ADD_ENUM_GET_SET(dtype, DataType, static_cast<int64_t>(DataType::kUnknown));
5454
ADD_ENUM_GET_SET(format, TensorFormat, static_cast<int64_t>(TensorFormat::kContiguous));
5555

5656
core::ir::Input toInternalInput();

0 commit comments

Comments (0)