import base64
from collections import defaultdict
from typing import Any, List

import torch
from torch_tensorrt.dynamo.utils import input_is_dynamic, unwrap_tensor_shape


@torch.library.register_fake("tensorrt::execute_engine")  # type: ignore
def fake_tensorrt_execute_engine(
    inputs: List[torch.Tensor], fake_trt_engine: Any
) -> Any:
    """
    Meta kernel that infers output shapes from the TRT engine and the given
    inputs, and returns fake tensors with those shapes.
    """
    # Here's what we are doing:
    # 1) Check whether the inputs are dynamic (i.e. their shapes contain SymInts)
    # 2) For dynamic inputs, gather the min and max shapes of all inputs
    # 3) Feed those min and max input shapes to TensorRT's set/get shapes
    #    mechanism to capture the corresponding min and max output shapes
    # 4) Create a new symbolic fake tensor from the min and max shape of each
    #    output and return them
    # 5) For static inputs, the output shapes are static and no SymInts are needed
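    # For example (illustrative values only): an input traced with symbolic shape
    # (s0, 3, 224, 224) might have a min shape of (1, 3, 224, 224) and a max
    # shape of (8, 3, 224, 224).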
    is_dynamic_execution = input_is_dynamic(inputs)
    if is_dynamic_execution:
        modes = ["min", "max", "opt"]
    else:
        modes = ["opt"]

    # Get the underlying TRTEngine object and infer output shapes from the input shapes
    trt_engine = fake_trt_engine.wrapped_obj.engine
    outputs_mode_dict = defaultdict(list)
    for mode in modes:
        input_shapes = [unwrap_tensor_shape(input, mode=mode) for input in inputs]
        proxy_outputs = trt_engine.infer_outputs(input_shapes)
        outputs_mode_dict[mode].extend(proxy_outputs)
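
    # At this point outputs_mode_dict maps each mode to one proxy tensor per
    # engine output, e.g. (illustrative) {"min": [(1, 1000)-shaped proxy],
    # "opt": [(4, 1000)-shaped proxy], "max": [(8, 1000)-shaped proxy]}
    # for a single-output engine.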

    # Store the number of outputs
    if {"min", "max"}.issubset(outputs_mode_dict):
        assert len(outputs_mode_dict["min"]) == len(outputs_mode_dict["max"])
        num_outputs = len(outputs_mode_dict["min"])
    elif "opt" in outputs_mode_dict:
        num_outputs = len(outputs_mode_dict["opt"])

    fake_outputs = []
    for out_idx in range(num_outputs):
        output_shape = []
        if is_dynamic_execution:
            # Create the symbolic output shape using unbacked SymInts.
            # Note: we can't establish a relationship between an incoming symbolic
            # input shape (e.g. s0) and TensorRT's output shape (represented as an
            # unbacked u0). This situation didn't affect compilation results or
            # serialization during our testing.
            output_min_shape = outputs_mode_dict["min"][out_idx].size()
            output_opt_shape = outputs_mode_dict["opt"][out_idx].size()
            output_max_shape = outputs_mode_dict["max"][out_idx].size()

            ctx = torch._custom_ops.get_ctx()
            for min_val, opt_val, max_val in zip(
                output_min_shape, output_opt_shape, output_max_shape
            ):
                if min_val != max_val:
                    output_sym_int = ctx.new_dynamic_size(min=min_val, max=max_val)
                    # Record the opt value as the hint for this new symbol
                    output_sym_int_shape_env = output_sym_int.node.shape_env
                    output_sym_int_shape_env.add_var_to_val(
                        output_sym_int.node.expr, opt_val
                    )
                    output_shape.append(output_sym_int)
                else:
                    output_shape.append(min_val)
        else:
            output_shape.extend(outputs_mode_dict["opt"][out_idx].size())

        fake_outputs.append(
            torch.empty(output_shape, dtype=outputs_mode_dict["opt"][out_idx].dtype)
        )

    return fake_outputs


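# Illustrative sketch (not how user code calls this): the fake kernel above is
# exercised when the op is traced under FakeTensorMode, e.g. during torch.export:
#
#   from torch._subclasses.fake_tensor import FakeTensorMode
#   with FakeTensorMode():
#       fake_outs = torch.ops.tensorrt.execute_engine([fake_input], fake_engine)
#
# where `fake_engine` stands for the FakeScriptObject (wrapping the FakeTRTEngine
# below) that the export machinery substitutes for the real engine.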
@torch._library.register_fake_class("tensorrt::Engine")
class FakeTRTEngine:
    def __init__(self, engine_info: List[str]) -> None:
        self.engine = torch.classes.tensorrt.Engine(engine_info)

    @classmethod
    def __obj_unflatten__(cls, flattened_tq: Any) -> Any:
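        # `flattened_tq` holds the real engine's state flattened into
        # (attribute_name, value) pairs; the serialized engine blob is
        # base64-encoded in that state, so it is decoded before rebuilding
        # the engine below.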
        engine_idx = torch.ops.tensorrt.ENGINE_IDX()
        engine_info = [info[1] for info in flattened_tq]
        engine_info[engine_idx] = base64.b64decode(engine_info[engine_idx])

        return cls(engine_info)

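    # The methods below are deliberate no-op stubs: during fake tracing only the
    # real engine's method signatures need to be mirrored, not their behavior.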
    def enable_profiling(self) -> Any:
        pass

    def disable_profiling(self) -> Any:
        pass

    def dump_engine_layer_info_to_file(self, path: str) -> Any:
        pass

    def dump_engine_layer_info(self) -> Any:
        pass

    def get_engine_layer_info(self) -> Any:
        pass

    def profile_path_prefix_getter(self) -> Any:
        pass

    def profile_path_prefix_setter(self) -> Any:
        pass

    def device_memory_budget_getter(self) -> Any:
        pass

    def device_memory_budget_setter(self) -> Any:
        pass

    def streamable_device_memory_budget_getter(self) -> Any:
        pass

    def automatic_device_memory_budget_getter(self) -> Any:
        pass

    def infer_outputs(self, input_shapes: List[Any]) -> Any:
        pass

    def __setstate__(self, serialized_state: List[str]) -> Any:
        pass