Skip to content

Commit 90be642

Browse files
ArynthonArin Toaca
authored andcommitted
Add fp16 support for python backend.
1 parent 8716c9b commit 90be642

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

onnx_tensorrt/backend.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def count_trailing_ones(vals):
5858

5959

6060
class TensorRTBackendRep(BackendRep):
61-
def __init__(self, model, device, max_batch_size=32,
61+
def __init__(self, model, device, max_batch_size=32, fp16_mode=False,
6262
max_workspace_size=None, serialize_engine=False, **kwargs):
6363
if not isinstance(device, Device):
6464
device = Device(device)
@@ -89,6 +89,7 @@ def __init__(self, model, device, max_batch_size=32,
8989

9090
self.builder.max_batch_size = max_batch_size
9191
self.builder.max_workspace_size = max_workspace_size
92+
self.builder.fp16_mode = fp16_mode
9293

9394
for layer in self.network:
9495
print(layer.name)
@@ -231,4 +232,4 @@ def supports_device(cls, device_str):
231232
prepare = TensorRTBackend.prepare
232233
run_node = TensorRTBackend.run_node
233234
run_model = TensorRTBackend.run_model
234-
supports_device = TensorRTBackend.supports_device
235+
supports_device = TensorRTBackend.supports_device

0 commit comments

Comments
 (0)