|
| 1 | +From d9b35495da58038fd5045cc0e2c1f0416f8e62f0 Mon Sep 17 00:00:00 2001 |
| 2 | +From: Chao Zhang < [email protected]> |
| 3 | +Date: Tue, 21 Jun 2022 15:38:23 +0000 |
| 4 | +Subject: [PATCH] Fix getitem for Py<3.7 |
| 5 | + |
| 6 | +--- |
| 7 | + torch2trt/torch2trt.py | 13 ++++++++++++- |
| 8 | + 1 file changed, 12 insertions(+), 1 deletion(-) |
| 9 | + |
| 10 | +diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py |
| 11 | +index 3aa6946..9528f1a 100644 |
| 12 | +--- a/torch2trt/torch2trt.py |
| 13 | ++++ b/torch2trt/torch2trt.py |
| 14 | +@@ -310,6 +310,14 @@ def attach_converter(ctx, method, converter, method_str): |
| 15 | + return wrapper |
| 16 | + |
| 17 | + |
| 18 | ++def _getitem_wrapper(method=torch.Tensor.__getitem__): |
| 19 | ++ def wrapper(arg0, arg1): |
| 20 | ++ if type(arg1) is torch.Tensor: |
| 21 | ++ arg1 = (arg1, ) |
| 22 | ++ return method(arg0, arg1) |
| 23 | ++ return wrapper |
| 24 | ++ |
| 25 | ++ |
| 26 | + class ConversionHook(object): |
| 27 | + """Attaches TensorRT converter to PyTorch method call""" |
| 28 | + |
| 29 | +@@ -330,7 +338,10 @@ class ConversionHook(object): |
| 30 | + ) |
| 31 | + |
| 32 | + def __exit__(self, type, val, tb): |
| 33 | +- self._set_method(self.converter['method_impl']) |
| 34 | ++ if '__getitem__' in self.converter['method_str']: |
| 35 | ++ self._set_method(_getitem_wrapper()) |
| 36 | ++ else: |
| 37 | ++ self._set_method(self.converter['method_impl']) |
| 38 | + |
| 39 | + def default_input_names(num_inputs): |
| 40 | + return ["input_%d" % i for i in range(num_inputs)] |
| 41 | +-- |
| 42 | +2.32.0 |
| 43 | + |
0 commit comments