Skip to content

Commit 5031324

Browse files
committed
refactor: Addressing PR comments
Signed-off-by: Naren Dasan <[email protected]>
1 parent 0815680 commit 5031324

File tree

4 files changed

+9
-7
lines changed

4 files changed

+9
-7
lines changed

docsrc/tutorials/use_from_pytorch.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ at the documentation for the TRTorch ``TensorRTCompileSpec`` API.
4242
"device": {
4343
"device_type": trtorch.DeviceType.GPU,
4444
"gpu_id": 0,
45+
"dla_core": 0,
4546
"allow_gpu_fallback": True
4647
},
4748
"capability": trtorch.EngineCapability.default,

py/trtorch/_compile_spec.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,10 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
195195
} # Dynamic input shape for input #2
196196
],
197197
"device": {
198-
"device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
199-
"gpu_id": 0, # Target gpu id to run engine (Use Xavier as gpu id for DLA)
200-
"dla_core": 0, # (DLA only) Target dla core id to run engine
201-
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
198+
"device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
199+
"gpu_id": 0, # Target gpu id to run engine (Use Xavier as gpu id for DLA)
200+
"dla_core": 0, # (DLA only) Target dla core id to run engine
201+
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
202202
},
203203
"op_precision": torch.half, # Operating precision set to FP16
204204
"refit": False, # enable refit

py/trtorch/csrc/tensorrt_classes.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ struct Device : torch::CustomClassHolder {
6969
allow_gpu_fallback(false) // allow_gpu_fallback
7070
{}
7171

72-
ADD_ENUM_GET_SET(device_type, DeviceType, 1);
72+
ADD_ENUM_GET_SET(device_type, DeviceType, static_cast<int64_t>(DeviceType::kDLA));
7373
ADD_FIELD_GET_SET(gpu_id, int64_t);
7474
ADD_FIELD_GET_SET(dla_core, int64_t);
7575
ADD_FIELD_GET_SET(allow_gpu_fallback, bool);
@@ -98,11 +98,11 @@ struct CompileSpec : torch::CustomClassHolder {
9898
device = *d;
9999
}
100100

101-
ADD_ENUM_GET_SET(op_precision, DataType, 2);
101+
ADD_ENUM_GET_SET(op_precision, DataType, static_cast<int64_t>(DataType::kChar));
102102
ADD_FIELD_GET_SET(refit, bool);
103103
ADD_FIELD_GET_SET(debug, bool);
104104
ADD_FIELD_GET_SET(strict_types, bool);
105-
ADD_ENUM_GET_SET(capability, EngineCapability, 2);
105+
ADD_ENUM_GET_SET(capability, EngineCapability, static_cast<int64_t>(EngineCapability::kSAFE_DLA));
106106
ADD_FIELD_GET_SET(num_min_timing_iters, int64_t);
107107
ADD_FIELD_GET_SET(num_avg_timing_iters, int64_t);
108108
ADD_FIELD_GET_SET(workspace_size, int64_t);

tests/py/test_to_backend_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def setUp(self):
2222
"device": {
2323
"device_type": trtorch.DeviceType.GPU,
2424
"gpu_id": 0,
25+
"dla_core": 0,
2526
"allow_gpu_fallback": True
2627
},
2728
"capability": trtorch.EngineCapability.default,

0 commit comments

Comments (0)