|
2 | 2 | from cuda import cudart |
3 | 3 | import torch |
4 | 4 | import tensorrt as trt |
5 | | - |
| 5 | +import subprocess |
| 6 | +from collections import defaultdict |
6 | 7 | from collections import OrderedDict |
7 | 8 | from polygraphy.backend.common import bytes_from_path |
8 | 9 | from polygraphy.backend.trt import engine_from_bytes |
@@ -78,6 +79,124 @@ def check_dims(self, batch_size, image_height, image_width, compression_factor = |
78 | 79 | assert latent_width >= min_latent_shape and latent_width <= max_latent_shape |
79 | 80 | return (latent_height, latent_width) |
80 | 81 |
|
| 82 | + def build( |
| 83 | + self, |
| 84 | + onnx_path, |
| 85 | + strongly_typed=False, |
| 86 | + fp16=False, |
| 87 | + bf16=True, |
| 88 | + tf32=True, |
| 89 | + int8=False, |
| 90 | + fp8=False, |
| 91 | + input_profile=None, |
| 92 | + enable_refit=False, |
| 93 | + enable_all_tactics=False, |
| 94 | + timing_cache=None, |
| 95 | + update_output_names=None, |
| 96 | + native_instancenorm=True, |
| 97 | + verbose=False, |
| 98 | + weight_streaming=False, |
| 99 | + builder_optimization_level=3, |
| 100 | + precision_constraints='none', |
| 101 | + ): |
| 102 | + print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") |
| 103 | + |
| 104 | + # Handle weight streaming case: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#streaming-weights. |
| 105 | + if weight_streaming: |
| 106 | + strongly_typed, fp16, bf16, int8, fp8 = True, False, False, False, False |
| 107 | + |
| 108 | + # Base command |
| 109 | + build_command = [f"polygraphy convert {onnx_path} --convert-to trt --output {self.engine_path}"] |
| 110 | + |
| 111 | + # Precision flags |
| 112 | + build_args = [ |
| 113 | + "--fp16" if fp16 else "", |
| 114 | + "--bf16" if bf16 else "", |
| 115 | + "--tf32" if tf32 else "", |
| 116 | + "--fp8" if fp8 else "", |
| 117 | + "--int8" if int8 else "", |
| 118 | + "--strongly-typed" if strongly_typed else "", |
| 119 | + ] |
| 120 | + |
| 121 | + # Additional arguments |
| 122 | + build_args.extend([ |
| 123 | + "--weight-streaming" if weight_streaming else "", |
| 124 | + "--refittable" if enable_refit else "", |
| 125 | + "--tactic-sources" if not enable_all_tactics else "", |
| 126 | + "--onnx-flags native_instancenorm" if native_instancenorm else "", |
| 127 | + f"--builder-optimization-level {builder_optimization_level}", |
| 128 | + f"--precision-constraints {precision_constraints}", |
| 129 | + ]) |
| 130 | + |
| 131 | + # Timing cache |
| 132 | + if timing_cache: |
| 133 | + build_args.extend([ |
| 134 | + f"--load-timing-cache {timing_cache}", |
| 135 | + f"--save-timing-cache {timing_cache}" |
| 136 | + ]) |
| 137 | + |
| 138 | + # Verbosity setting |
| 139 | + verbosity = "extra_verbose" if verbose else "error" |
| 140 | + build_args.append(f"--verbosity {verbosity}") |
| 141 | + |
| 142 | + # Output names |
| 143 | + if update_output_names: |
| 144 | + print(f"Updating network outputs to {update_output_names}") |
| 145 | + # build_args.append(f"--trt-outputs {' '.join(update_output_names)}") |
| 146 | + build_args.append(f"--trt-outputs {update_output_names}") |
| 147 | + |
| 148 | + # Input profiles |
| 149 | + if input_profile: |
| 150 | + profile_args = defaultdict(str) |
| 151 | + for name, dims in input_profile.items(): |
| 152 | + assert len(dims) == 3 |
| 153 | + profile_args["--trt-min-shapes"] += f"{name}:{str(list(dims[0])).replace(' ', '')} " |
| 154 | + profile_args["--trt-opt-shapes"] += f"{name}:{str(list(dims[1])).replace(' ', '')} " |
| 155 | + profile_args["--trt-max-shapes"] += f"{name}:{str(list(dims[2])).replace(' ', '')} " |
| 156 | + |
| 157 | + build_args.extend(f"{k} {v}" for k, v in profile_args.items()) |
| 158 | + |
| 159 | + # Filter out empty strings and join command |
| 160 | + build_args = [arg for arg in build_args if arg] |
| 161 | + final_command = ' '.join(build_command + build_args) |
| 162 | + |
| 163 | + # Execute command with improved error handling |
| 164 | + try: |
| 165 | + print(f"Engine build command: {final_command}") |
| 166 | + subprocess.run(final_command, check=True, shell=True) |
| 167 | + except subprocess.CalledProcessError as exc: |
| 168 | + error_msg = ( |
| 169 | + f"Failed to build TensorRT engine. Error details:\n" |
| 170 | + f"Command: {exc.cmd}\n" |
| 171 | + ) |
| 172 | + raise RuntimeError(error_msg) from exc |
| 173 | + |
| 174 | + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_shape, compression_factor=8, min_batch=1, max_batch=8, min_image_shape=256, max_image_shape=1344, min_latent_shape=16, max_latent_shape=1024): |
| 175 | + min_batch = batch_size if static_batch else self.min_batch |
| 176 | + max_batch = batch_size if static_batch else self.max_batch |
| 177 | + latent_height = image_height // compression_factor |
| 178 | + latent_width = image_width // compression_factor |
| 179 | + min_image_height = image_height if static_shape else min_image_shape |
| 180 | + max_image_height = image_height if static_shape else max_image_shape |
| 181 | + min_image_width = image_width if static_shape else min_image_shape |
| 182 | + max_image_width = image_width if static_shape else max_image_shape |
| 183 | + min_latent_height = latent_height if static_shape else min_latent_shape |
| 184 | + max_latent_height = latent_height if static_shape else max_latent_shape |
| 185 | + min_latent_width = latent_width if static_shape else min_latent_shape |
| 186 | + max_latent_width = latent_width if static_shape else max_latent_shape |
| 187 | + return ( |
| 188 | + min_batch, |
| 189 | + max_batch, |
| 190 | + min_image_height, |
| 191 | + max_image_height, |
| 192 | + min_image_width, |
| 193 | + max_image_width, |
| 194 | + min_latent_height, |
| 195 | + max_latent_height, |
| 196 | + min_latent_width, |
| 197 | + max_latent_width, |
| 198 | + ) |
| 199 | + |
81 | 200 |
|
82 | 201 |
|
83 | 202 |
|
0 commit comments