import os
import argparse

import tensorrt as trt


def main(onnx_path, engine_path, max_batchsize, opt_batchsize, min_batchsize, use_fp16=True, verbose=False) -> None:
| 6 | + """ Convert ONNX model to TensorRT engine. |
| 7 | + Args: |
| 8 | + onnx_path (str): Path to the input ONNX model. |
| 9 | + engine_path (str): Path to save the output TensorRT engine. |
| 10 | + use_fp16 (bool): Whether to use FP16 precision. |
| 11 | + verbose (bool): Whether to enable verbose logging. |
| 12 | + """ |
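    # One logger is shared by the builder and parser; VERBOSE emits detailed
    # per-layer build information, while INFO keeps output concise.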
    logger = trt.Logger(trt.Logger.VERBOSE if verbose else trt.Logger.INFO)

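    # The ONNX parser requires an explicit-batch network, which also allows
    # the dynamic-shape optimization profile configured below.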
    builder = trt.Builder(logger)
    network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(network_flags)

    parser = trt.OnnxParser(network, logger)

    if not os.path.isfile(onnx_path):
        raise FileNotFoundError(f"ONNX file not found: {onnx_path}")

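    # parser.parse populates `network` in place; on failure, each recorded
    # parser error is printed before aborting.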
    print(f"[INFO] Loading ONNX file from {onnx_path}")
    with open(onnx_path, "rb") as f:
        if not parser.parse(f.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("Failed to parse ONNX file")

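    # FASTER_DYNAMIC_SHAPES_0805 is a preview feature (TensorRT 8.5+) that
    # improves performance of engines built with dynamic input shapes.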
    config = builder.create_builder_config()
    config.set_preview_feature(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805, True)
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)  # 1 GiB

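    # Enable FP16 kernels only where the hardware has native support;
    # builder.platform_has_fast_fp16 reports that capability.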
    if use_fp16:
        if builder.platform_has_fast_fp16:
            config.set_flag(trt.BuilderFlag.FP16)
            print("[INFO] FP16 optimization enabled.")
        else:
            print("[WARNING] FP16 not supported on this platform. Proceeding with FP32.")

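    # The optimization profile bounds the dynamic batch dimension; the tensor
    # names passed to set_shape must match the input names in the ONNX graph.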
    profile = builder.create_optimization_profile()
    profile.set_shape(
        "images",
        min=(min_batchsize, 3, 640, 640),
        opt=(opt_batchsize, 3, 640, 640),
        max=(max_batchsize, 3, 640, 640),
    )
    profile.set_shape("orig_target_sizes", min=(1, 2), opt=(1, 2), max=(1, 2))
    config.add_optimization_profile(profile)

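    # build_serialized_network supersedes the deprecated builder.build_engine
    # API; it returns an IHostMemory buffer (or None on failure) that can be
    # written to disk directly.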
    print("[INFO] Building TensorRT engine...")
    serialized_engine = builder.build_serialized_network(network, config)

    if serialized_engine is None:
        raise RuntimeError("Failed to build the engine.")

    print(f"[INFO] Saving engine to {engine_path}")
    with open(engine_path, "wb") as f:
        f.write(serialized_engine)
    print("[INFO] Engine export complete.")


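# Example invocation (script and model names are illustrative):
#   python export_trt.py --onnx model.onnx --saveEngine model.engine \
#       --minBatchSize 1 --optBatchSize 16 --maxBatchSize 32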
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert ONNX to TensorRT Engine")
    parser.add_argument("--onnx", "-i", type=str, required=True, help="Path to input ONNX model file")
    parser.add_argument("--saveEngine", "-o", type=str, default="model.engine", help="Path to output TensorRT engine file")
    parser.add_argument("--maxBatchSize", "-Mb", type=int, default=32, help="Maximum batch size for inference")
    parser.add_argument("--optBatchSize", "-ob", type=int, default=16, help="Optimal batch size for inference")
    parser.add_argument("--minBatchSize", "-mb", type=int, default=1, help="Minimum batch size for inference")
| 68 | + parser.add_argument("--fp16", default=True, action="store_true", help="Enable FP16 precision mode") |
| 69 | + parser.add_argument("--verbose", action="store_true", help="Enable verbose logging") |
| 70 | + |
| 71 | + args = parser.parse_args() |
| 72 | + |
| 73 | + main( |
| 74 | + onnx_path=args.onnx, |
| 75 | + engine_path=args.saveEngine, |
| 76 | + max_batchsize=args.maxBatchSize, |
| 77 | + opt_batchsize=args.optBatchSize, |
| 78 | + min_batchsize=args.minBatchSize, |
| 79 | + use_fp16=args.fp16, |
| 80 | + verbose=args.verbose |
| 81 | + ) |