@@ -26,7 +26,7 @@ def fake_tensorrt_execute_engine(
2626 modes = ["opt" ]
2727
2828 # Get the TRTEngine class and infer output shapes based on input shapes
29- trt_engine = fake_trt_engine .wrapped_obj . engine
29+ trt_engine = fake_trt_engine .real_obj
3030 outputs_mode_dict = defaultdict (list )
3131 for mode in modes :
3232 input_shapes = [unwrap_tensor_shape (input , mode = mode ) for input in inputs ]
@@ -79,7 +79,21 @@ def fake_tensorrt_execute_engine(
7979@torch ._library .register_fake_class ("tensorrt::Engine" )
8080class FakeTRTEngine :
8181 def __init__ (self , engine_info : List [str ]) -> None :
82- self .engine = torch .classes .tensorrt .Engine (engine_info )
82+ self .version = engine_info [torch .ops .tensorrt .ABI_TARGET_IDX ()]
83+ self .name = engine_info [torch .ops .tensorrt .NAME_IDX ()]
84+ self .device_info = engine_info [torch .ops .tensorrt .DEVICE_IDX ()]
85+ self .serialized_engine = engine_info [torch .ops .tensorrt .ENGINE_IDX ()]
86+ self .in_binding_names = engine_info [
87+ torch .ops .tensorrt .INPUT_BINDING_NAMES_IDX ()
88+ ]
89+ self .out_binding_names = engine_info [
90+ torch .ops .tensorrt .OUTPUT_BINDING_NAMES_IDX ()
91+ ]
92+ self .hardware_compatible = engine_info [torch .ops .tensorrt .HW_COMPATIBLE_IDX ()]
93+ self .serialized_metadata = engine_info [
94+ torch .ops .tensorrt .SERIALIZED_METADATA_IDX ()
95+ ]
96+ self .target_platform = engine_info [torch .ops .tensorrt .TARGET_PLATFORM_IDX ()]
8397
8498 @classmethod
8599 def __obj_unflatten__ (cls , flattened_tq : Any ) -> Any :
@@ -127,3 +141,6 @@ def infer_outputs(self, input_shapes: List[Any]) -> Any:
127141
128142 def __setstate__ (self , serialized_state : List [str ]) -> Any :
129143 pass
144+
145+ def __getstate__ (self ) -> Any :
146+ pass
0 commit comments