diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 4ba7e0faeb..26bb447b4f 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -152,10 +152,10 @@ def initialize_globals(channel: str, build_python_only: bool) -> None: "12.4": "pytorch/manylinux2_28-builder:cuda12.4", "12.6": "pytorch/manylinux2_28-builder:cuda12.6", **{ - gpu_arch: f"pytorch/manylinux-builder:rocm{gpu_arch}" + gpu_arch: f"pytorch/manylinux2_28-builder:rocm{gpu_arch}" for gpu_arch in ROCM_ARCHES }, - CPU: "pytorch/manylinux-builder:cpu", + CPU: "pytorch/manylinux2_28-builder:cpu", XPU: "pytorch/manylinux2_28-builder:xpu", # TODO: Migrate CUDA_AARCH64 image to manylinux2_28_aarch64-builder:cuda12.4 CPU_AARCH64: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64", @@ -163,7 +163,7 @@ def initialize_globals(channel: str, build_python_only: bool) -> None: } LIBTORCH_CONTAINER_IMAGES = { **{ - (gpu_arch, PRE_CXX11_ABI): f"pytorch/manylinux-builder:cuda{gpu_arch}" + (gpu_arch, PRE_CXX11_ABI): f"pytorch/manylinux2_28-builder:cuda{gpu_arch}" for gpu_arch in CUDA_ARCHES }, **{ @@ -171,14 +171,14 @@ def initialize_globals(channel: str, build_python_only: bool) -> None: for gpu_arch in CUDA_ARCHES }, **{ - (gpu_arch, PRE_CXX11_ABI): f"pytorch/manylinux-builder:rocm{gpu_arch}" + (gpu_arch, PRE_CXX11_ABI): f"pytorch/manylinux2_28-builder:rocm{gpu_arch}" for gpu_arch in ROCM_ARCHES }, **{ (gpu_arch, CXX11_ABI): f"pytorch/libtorch-cxx11-builder:rocm{gpu_arch}" for gpu_arch in ROCM_ARCHES }, - (CPU, PRE_CXX11_ABI): "pytorch/manylinux-builder:cpu", + (CPU, PRE_CXX11_ABI): "pytorch/manylinux2_28-builder:cpu", (CPU, CXX11_ABI): "pytorch/libtorch-cxx11-builder:cpu", } diff --git a/.github/workflows/build-test-linux.yml b/.github/workflows/build-test-linux.yml index 72d7e21b5c..b0a487bb79 100644 --- a/.github/workflows/build-test-linux.yml +++ b/.github/workflows/build-test-linux.yml @@ -137,7 +137,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ popd tests-py-dynamo-fe: diff --git a/.github/workflows/build-test-tensorrt-linux.yml b/.github/workflows/build-test-tensorrt-linux.yml index cfad7274dc..625ffe9a31 100644 --- a/.github/workflows/build-test-tensorrt-linux.yml +++ b/.github/workflows/build-test-tensorrt-linux.yml @@ -129,7 +129,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ popd tests-py-dynamo-fe: @@ -314,4 +314,4 @@ jobs: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }} - cancel-in-progress: true \ No newline at end of file + cancel-in-progress: true diff --git a/.github/workflows/build-test-tensorrt-windows.yml b/.github/workflows/build-test-tensorrt-windows.yml index d2be9febd7..fe812e1b9d 100644 --- a/.github/workflows/build-test-tensorrt-windows.yml +++ b/.github/workflows/build-test-tensorrt-windows.yml @@ -132,7 +132,7 @@ jobs: export CI_BUILD=1 pushd . 
cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ popd tests-py-dynamo-fe: @@ -298,4 +298,4 @@ jobs: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref_name }}-${{ inputs.repository }}-${{ github.event_name == 'workflow_dispatch' }}-${{ inputs.job-name }} - cancel-in-progress: true \ No newline at end of file + cancel-in-progress: true diff --git a/.github/workflows/build-test-windows.yml b/.github/workflows/build-test-windows.yml index c2b05d8994..c227d14a0f 100644 --- a/.github/workflows/build-test-windows.yml +++ b/.github/workflows/build-test-windows.yml @@ -119,7 +119,7 @@ jobs: export CI_BUILD=1 pushd . cd tests/py/dynamo - python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/ + python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/ popd tests-py-dynamo-fe: diff --git a/examples/dynamo/engine_caching_bert_example.py b/examples/dynamo/engine_caching_bert_example.py index 989913bd31..1148d4f792 100644 --- a/examples/dynamo/engine_caching_bert_example.py +++ b/examples/dynamo/engine_caching_bert_example.py @@ -52,7 +52,7 @@ def compile_bert(iterations=3): "truncate_double": True, "debug": False, "min_block_size": 1, - "make_refittable": True, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache", diff --git a/examples/dynamo/engine_caching_example.py b/examples/dynamo/engine_caching_example.py index 28ff73aa72..fb4c341077 100644 --- a/examples/dynamo/engine_caching_example.py +++ b/examples/dynamo/engine_caching_example.py @@ -63,7 +63,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): # in a subsequent compilation, either as part of this session or a new session, the cache will # pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude. # As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``), -# the engine must be refittable (``make_refittable=True``). See :ref:`refit_engine_example` for more details. +# the engine must be refittable (``immutable_weights=False``). See :ref:`refit_engine_example` for more details. 
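For context, a minimal sketch (not part of the diff) of how the renamed flag combines with engine caching after this change, assuming the public `torch_tensorrt.compile` entry point and a torchvision ResNet purely for illustration:

import torch
import torch_tensorrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn(2, 3, 224, 224).to("cuda")]

# Caching only stores refittable engines, so the weights must not be immutable;
# a cache hit pulls the stored engine and refits it with the current weights.
trt_gm = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    immutable_weights=False,   # replaces the deprecated make_refittable=True
    cache_built_engines=True,  # insert newly built engines into the cache
    reuse_cached_engines=True, # load (and refit) matching cached engines
    engine_cache_dir="/tmp/torch_trt_engine_cache",  # illustrative path
)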
def torch_compile(iterations=3): @@ -97,7 +97,7 @@ def torch_compile(iterations=3): "enabled_precisions": enabled_precisions, "debug": debug, "min_block_size": min_block_size, - "make_refittable": True, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, }, @@ -157,7 +157,7 @@ def dynamo_compile(iterations=3): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, + immutable_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, engine_cache_size=1 << 30, # 1GB @@ -268,7 +268,7 @@ def torch_compile_my_cache(iterations=3): "enabled_precisions": enabled_precisions, "debug": debug, "min_block_size": min_block_size, - "make_refittable": True, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "custom_engine_cache": engine_cache, diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py index b68c9a11ee..8b62855c32 100644 --- a/examples/dynamo/mutable_torchtrt_module_example.py +++ b/examples/dynamo/mutable_torchtrt_module_example.py @@ -31,7 +31,7 @@ settings = { "use_python": False, "enabled_precisions": {torch.float32}, - "make_refittable": True, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -80,7 +80,7 @@ "use_python_runtime": True, "enabled_precisions": {torch.float16}, "debug": True, - "make_refittable": True, + "immutable_weights": False, } model_id = "runwayml/stable-diffusion-v1-5" diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py index f93b097385..66a1a70964 100644 --- a/examples/dynamo/refit_engine_example.py +++ b/examples/dynamo/refit_engine_example.py @@ -47,7 +47,7 @@ # --------------------------------------- # # The initial step is to compile a module and save it as normal. Note that there is an -# additional parameter `make_refittable` that is set to `True`. This parameter is used to +# additional parameter `immutable_weights` that is set to `False`. This parameter is used to # indicate that the engine being built should support weight refitting later. Engines built without # these settings will not be able to be refit.
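For context, a minimal sketch (not part of the diff) of the resulting refit workflow, assuming `torch.export` tracing and the `torch_tensorrt.dynamo.refit_module_weights` helper touched later in this diff; keyword names are illustrative:

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

# Build a refittable engine: immutable_weights=False replaces make_refittable=True.
model = models.resnet18(pretrained=False).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    inputs,
    immutable_weights=False,
    reuse_cached_engines=False,
)

# Later, swap in a new set of weights without rebuilding the engine.
model2 = models.resnet18(pretrained=True).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))
refitted_gm = torch_trt.dynamo.refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=exp_program2,
    arg_inputs=inputs,
)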
# @@ -69,7 +69,7 @@ debug=debug, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, - make_refittable=True, + immutable_weights=False, reuse_cached_engines=False, ) # Output is a torch.fx.GraphModule diff --git a/py/ci/Dockerfile.ci b/py/ci/Dockerfile.ci index eddf12cefb..823c8bb7a1 100644 --- a/py/ci/Dockerfile.ci +++ b/py/ci/Dockerfile.ci @@ -1,4 +1,4 @@ -FROM pytorch/manylinux-builder:cuda12.4 +FROM pytorch/manylinux2_28-builder:cuda12.6 RUN yum install -y ninja-build diff --git a/py/torch_tensorrt/_enums.py b/py/torch_tensorrt/_enums.py index a580e6efbb..eaefb68ce5 100644 --- a/py/torch_tensorrt/_enums.py +++ b/py/torch_tensorrt/_enums.py @@ -220,7 +220,7 @@ def _from( return dtype.f32 elif t == np.float64: return dtype.f64 - elif t == np.bool: + elif t == np.bool_: return dtype.b # TODO: Consider using ml_dtypes when issues like this are resolved: # https://github.com/pytorch/pytorch/issues/109873 @@ -1384,7 +1384,7 @@ def current_platform(cls) -> Platform: def __str__(self) -> str: return str(self.name) - @needs_torch_tensorrt_runtime + @needs_torch_tensorrt_runtime # type: ignore def _to_serialized_rt_platform(self) -> str: val: str = torch.ops.tensorrt._platform_unknown() diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 9859668cd9..88e66b0f3c 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -63,7 +63,6 @@ def cross_compile_for_windows( Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - make_refittable: bool = _defaults.MAKE_REFITTABLE, debug: bool = _defaults.DEBUG, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, @@ -93,6 +92,9 @@ def cross_compile_for_windows( custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE, use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING, use_fp32_acc: bool = _defaults.USE_FP32_ACC, + refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS, + strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS, + immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS, enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING, **kwargs: Any, ) -> torch.fx.GraphModule: @@ -132,7 +134,6 @@ def cross_compile_for_windows( assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels - refit (bool): Enable refitting debug (bool): Enable debuggable engine capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels @@ -164,6 +165,9 @@ def cross_compile_for_windows( custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored. use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. 
This is useful when users have mixed precision graphs. use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. + refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs. + strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required. + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored. enable_weight_streaming (bool): Enable weight streaming. **kwargs: Any, Returns: @@ -193,14 +197,44 @@ def cross_compile_for_windows( if "refit" in kwargs.keys(): warnings.warn( - "Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.", + "`refit` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted.", DeprecationWarning, stacklevel=2, ) - if make_refittable: - raise ValueError("Use flag make_refittable only. Flag refit is deprecated.") + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) else: - make_refittable = kwargs["refit"] + immutable_weights = not kwargs["refit"] + + if "make_refittable" in kwargs.keys(): + warnings.warn( + "`make_refittable` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted", + DeprecationWarning, + stacklevel=2, + ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `make_refittable` is deprecated." + ) + else: + immutable_weights = not kwargs["make_refittable"] + + if refit_identical_engine_weights: + if immutable_weights: + raise ValueError( + "`immutable_weights` must be False when `refit_identical_engine_weights` is True." + ) + + if ( + not immutable_weights + and not refit_identical_engine_weights + and enable_weight_streaming + ): + raise ValueError( + "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. 
This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" + ) engine_capability = EngineCapability._from(engine_capability) @@ -275,7 +309,6 @@ def cross_compile_for_windows( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, - "make_refittable": make_refittable, "engine_capability": engine_capability, "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, @@ -286,6 +319,9 @@ def cross_compile_for_windows( "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, + "refit_identical_engine_weights": refit_identical_engine_weights, + "strip_engine_weights": strip_engine_weights, + "immutable_weights": immutable_weights, "enable_cross_compile_for_windows": True, "enable_weight_streaming": enable_weight_streaming, } @@ -342,7 +378,6 @@ def compile( Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]] ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - make_refittable: bool = _defaults.MAKE_REFITTABLE, debug: bool = _defaults.DEBUG, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, @@ -372,6 +407,9 @@ def compile( custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE, use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING, use_fp32_acc: bool = _defaults.USE_FP32_ACC, + refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS, + strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS, + immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS, enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING, **kwargs: Any, ) -> torch.fx.GraphModule: @@ -413,7 +451,6 @@ def compile( assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels - refit (bool): Enable refitting debug (bool): Enable debuggable engine capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels @@ -445,6 +482,9 @@ def compile( custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored. use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. + refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs. + strip_engine_weights (bool): Strip engine weights from the serialized engine. 
This is useful when the engine is to be deployed in an environment where the weights are not required. + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored. enable_weight_streaming (bool): Enable weight streaming. **kwargs: Any, Returns: @@ -468,14 +508,44 @@ def compile( if "refit" in kwargs.keys(): warnings.warn( - "Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.", + "`refit` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted", + DeprecationWarning, + stacklevel=2, + ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["refit"] + + if "make_refittable" in kwargs.keys(): + warnings.warn( + "`make_refittable` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted", DeprecationWarning, stacklevel=2, ) - if make_refittable: - raise ValueError("Use flag make_refittable only. Flag refit is deprecated.") + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `make_refittable` is deprecated." + ) else: - make_refittable = kwargs["refit"] + immutable_weights = not kwargs["make_refittable"] + + if refit_identical_engine_weights: + if immutable_weights: + raise ValueError( + "`immutable_weights` must be False when `refit_identical_engine_weights` is True." + ) + + if ( + not immutable_weights + and not refit_identical_engine_weights + and enable_weight_streaming + ): + raise ValueError( + "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. 
This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" + ) if ( "enable_cross_compile_for_windows" in kwargs.keys() @@ -541,9 +611,6 @@ def compile( engine_cache = None if cache_built_engines or reuse_cached_engines: - assert ( - make_refittable - ), "Engine caching requires make_refittable to be set to True" engine_cache = ( custom_engine_cache if custom_engine_cache is not None @@ -574,7 +641,6 @@ def compile( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, - "make_refittable": make_refittable, "engine_capability": engine_capability, "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, @@ -587,6 +653,9 @@ def compile( "reuse_cached_engines": reuse_cached_engines, "use_explicit_typing": use_explicit_typing, "use_fp32_acc": use_fp32_acc, + "refit_identical_engine_weights": refit_identical_engine_weights, + "strip_engine_weights": strip_engine_weights, + "immutable_weights": immutable_weights, "enable_cross_compile_for_windows": False, "enable_weight_streaming": enable_weight_streaming, } @@ -861,7 +930,6 @@ def convert_exported_program_to_serialized_trt_engine( require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION, disable_tf32: bool = _defaults.DISABLE_TF32, sparse_weights: bool = _defaults.SPARSE_WEIGHTS, - make_refittable: bool = _defaults.MAKE_REFITTABLE, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, dla_sram_size: int = _defaults.DLA_SRAM_SIZE, @@ -872,6 +940,9 @@ def convert_exported_program_to_serialized_trt_engine( timing_cache_path: str = _defaults.TIMING_CACHE_PATH, use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING, use_fp32_acc: bool = _defaults.USE_FP32_ACC, + refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS, + strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS, + immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS, enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING, **kwargs: Any, ) -> bytes: @@ -922,7 +993,6 @@ def convert_exported_program_to_serialized_trt_engine( Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path disable_tf32 (bool): Whether to disable TF32 computation for TRT layers sparse_weights (bool): Whether to allow the builder to use sparse weights - refit (bool): Whether to build a refittable engine engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. @@ -933,6 +1003,9 @@ def convert_exported_program_to_serialized_trt_engine( timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. + refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. 
This will save time by reusing the same engine for different inputs. + strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required. + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored. enable_weight_streaming (bool): Enable weight streaming. Returns: bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs @@ -952,12 +1025,48 @@ def convert_exported_program_to_serialized_trt_engine( DeprecationWarning, stacklevel=2, ) + if "refit" in kwargs.keys(): warnings.warn( - "Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.", + "`refit` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted", + DeprecationWarning, + stacklevel=2, + ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["refit"] + + if "make_refittable" in kwargs.keys(): + warnings.warn( + "`make_refittable` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `make_refittable` is deprecated." + ) + else: + immutable_weights = not kwargs["make_refittable"] + + if refit_identical_engine_weights: + if immutable_weights: + raise ValueError( + "`immutable_weights` must be False when `refit_identical_engine_weights` is True." + ) + + if ( + not immutable_weights + and not refit_identical_engine_weights + and enable_weight_streaming + ): + raise ValueError( + "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. 
This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" + ) + if arg_inputs is None and inputs is None: raise AssertionError("'arg_inputs' and 'inputs' should not both be None.") @@ -1000,7 +1109,6 @@ def convert_exported_program_to_serialized_trt_engine( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, - "make_refittable": make_refittable, "engine_capability": engine_capability, "num_avg_timing_iters": num_avg_timing_iters, "dla_sram_size": dla_sram_size, @@ -1009,6 +1117,9 @@ def convert_exported_program_to_serialized_trt_engine( "timing_cache_path": timing_cache_path, "use_explicit_typing": use_explicit_typing, "use_fp32_acc": use_fp32_acc, + "refit_identical_engine_weights": refit_identical_engine_weights, + "strip_engine_weights": strip_engine_weights, + "immutable_weights": immutable_weights, "enable_weight_streaming": enable_weight_streaming, } diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index f6b97b1fbb..76630a75a5 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -26,7 +26,6 @@ USE_PYTHON_RUNTIME = False USE_FAST_PARTITIONER = True ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False -MAKE_REFITTABLE = False REQUIRE_FULL_COMPILATION = False DRYRUN = False HARDWARE_COMPATIBLE = False @@ -38,10 +37,13 @@ CACHE_BUILT_ENGINES = False REUSE_CACHED_ENGINES = False ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache") -ENGINE_CACHE_SIZE = 1073741824 +ENGINE_CACHE_SIZE = 5368709120 # 5GB CUSTOM_ENGINE_CACHE = None USE_EXPLICIT_TYPING = False USE_FP32_ACC = False +REFIT_IDENTICAL_ENGINE_WEIGHTS = False +STRIP_ENGINE_WEIGHTS = False +IMMUTABLE_WEIGHTS = True ENABLE_WEIGHT_STREAMING = False ENABLE_CROSS_COMPILE_FOR_WINDOWS = False diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index 519423e15d..f1041682f8 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -156,13 +156,26 @@ def _refit_single_trt_engine_with_gm( if torch_device.type == "cuda" else trt.TensorLocation.HOST ) + + constant_mapping: dict[str, Any] = weight_name_map.pop( + "constant_mapping", {} + ) # type: ignore mapping = construct_refit_mapping_from_weight_name_map( weight_name_map, new_gm.state_dict() ) + constant_mapping_with_type = {} + + for constant_name, val in constant_mapping.items(): + np_weight_type = val.dtype + val_tensor = torch.from_numpy(val).cuda() + trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType) + torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype) + constant_mapping_with_type[constant_name] = ( + val_tensor.clone().reshape(-1).contiguous().to(torch_dtype), + trt_dtype, + ) - # Debug Use - # correct = construct_refit_mapping(new_gm, input_list, settings) - # comparison = {k: (np.allclose(correct[k][0], mapping[k][0].cpu().numpy(), 1e-2, 1e-2), correct[k][0], mapping[k][0]) for k in mapping if k in correct} + mapping.update(constant_mapping_with_type) for layer_name in weight_list: if layer_name not in mapping: @@ -251,7 +264,7 @@ def refit_module_weights( ] assert ( encoded_metadata != "" - ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refittable=True" + ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the 
latest version" settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"] # Handle torch modules compiled_submodules_map = dict(compiled_submodules) @@ -269,8 +282,8 @@ def refit_module_weights( assert settings is not None assert ( - settings.make_refittable - ), "Refitting is not enabled. Please recompile the engine with refit=True." + not settings.immutable_weights + ), "Refitting is not enabled. Please recompile the engine with immutable_weights=False." if settings.debug: set_log_level(logger.parent, logging.DEBUG) @@ -449,17 +462,21 @@ def refit_module_weights( weight_name_map=None, ) - if isinstance(compiled_submodule, TorchTensorRTModule): - serialized_engine = bytes(engine.serialize()) - new_engine_info = list(engine_info) - new_engine_info[ENGINE_IDX] = serialized_engine - refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) - compiled_submodule.engine = refitted_engine + # clear EXCLUDE_WEIGHTS flag + serialization_config = engine.create_serialization_config() + serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + serialized_engine = engine.serialize_with_config(serialization_config) + + if isinstance( + compiled_submodule, (PythonTorchTensorRTModule, TorchTensorRTModule) + ): + compiled_submodule.engine = None # Clear the engine for TorchTensorRTModule, otherwise it won't be updated + compiled_submodule.serialized_engine = bytes(serialized_engine) + compiled_submodule.setup_engine() elif inline_module: - serialized_engine = bytes(engine.serialize()) new_engine_info = list(engine_info) - new_engine_info[ENGINE_IDX] = serialized_engine + new_engine_info[ENGINE_IDX] = bytes(serialized_engine) refitted_engine = torch.classes.tensorrt.Engine(tuple(new_engine_info)) setattr(compiled_module, f"{name}_engine", refitted_engine) diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 9062e2e539..7a22663af3 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -19,16 +19,18 @@ ENABLED_PRECISIONS, ENGINE_CAPABILITY, HARDWARE_COMPATIBLE, + IMMUTABLE_WEIGHTS, LAZY_ENGINE_INIT, - MAKE_REFITTABLE, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, NUM_AVG_TIMING_ITERS, OPTIMIZATION_LEVEL, PASS_THROUGH_BUILD_FAILURES, + REFIT_IDENTICAL_ENGINE_WEIGHTS, REQUIRE_FULL_COMPILATION, REUSE_CACHED_ENGINES, SPARSE_WEIGHTS, + STRIP_ENGINE_WEIGHTS, TIMING_CACHE_PATH, TRUNCATE_DOUBLE, USE_EXPLICIT_TYPING, @@ -69,7 +71,6 @@ class CompilationSettings: assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False disable_tf32 (bool): Whether to disable TF32 computation for TRT layers sparse_weights (bool): Whether to allow the builder to use sparse weights - refit (bool): Whether to build a refittable engine engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer. @@ -84,6 +85,9 @@ class CompilationSettings: reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. 
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. + refit_identical_engine_weights (bool): Whether to refit the engine with identical weights + strip_engine_weights (bool): Whether to strip the engine weights + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored enable_weight_streaming (bool): Enable weight streaming. enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built. True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows @@ -107,7 +111,6 @@ class CompilationSettings: disable_tf32: bool = DISABLE_TF32 assume_dynamic_shape_support: bool = ASSUME_DYNAMIC_SHAPE_SUPPORT sparse_weights: bool = SPARSE_WEIGHTS - make_refittable: bool = MAKE_REFITTABLE engine_capability: EngineCapability = field( default_factory=lambda: ENGINE_CAPABILITY ) @@ -123,6 +126,9 @@ class CompilationSettings: reuse_cached_engines: bool = REUSE_CACHED_ENGINES use_explicit_typing: bool = USE_EXPLICIT_TYPING use_fp32_acc: bool = USE_FP32_ACC + refit_identical_engine_weights: bool = REFIT_IDENTICAL_ENGINE_WEIGHTS + strip_engine_weights: bool = STRIP_ENGINE_WEIGHTS + immutable_weights: bool = IMMUTABLE_WEIGHTS enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS @@ -134,9 +140,11 @@ class CompilationSettings: "optimization_level", "disable_tf32", "sparse_weights", - "make_refittable", "engine_capability", "hardware_compatible", + "refit_identical_engine_weights", + "strip_engine_weights", # TODO: @Evan to remove this after implementing caching weight-stripped engines as default? 
+ "immutable_weights", "enable_weight_streaming", ) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index e15ed0495f..c8a30e656b 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -111,6 +111,10 @@ def _pretraced_backend( logger.warning( "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt" ) + if settings.strip_engine_weights: + logger.error( + "strip_engine_weights arg is not supported for torch.compile()" + ) trt_compiled = compile_module( gm, torchtrt_inputs, diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 03852ae6ae..d7c0ea449e 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -287,8 +287,21 @@ def _populate_trt_builder_config( if self.compilation_settings.disable_tf32: builder_config.clear_flag(trt.BuilderFlag.TF32) - if self.compilation_settings.make_refittable: - builder_config.set_flag(trt.BuilderFlag.REFIT) + if self.compilation_settings.immutable_weights: + # non-refittable engine + if self.compilation_settings.strip_engine_weights: + _LOGGER.warning("strip_engine_weights will be ignored.") + if self.compilation_settings.refit_identical_engine_weights: + _LOGGER.warning("refit_identical_engine_weights will be ignored.") + else: + # refittable engine + if self.compilation_settings.refit_identical_engine_weights: + builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL) + else: + builder_config.set_flag(trt.BuilderFlag.REFIT) + + if self.compilation_settings.strip_engine_weights: + builder_config.set_flag(trt.BuilderFlag.STRIP_PLAN) if strict_type_constraints: builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES) @@ -371,7 +384,6 @@ def find_weight( np_map: the map from weight name to np values in INetworkDefinition state_dict: state of the graph module """ - network_weight = np_map[weight_name] network_weight = torch.from_numpy(np_map[weight_name]).cuda() for sd_w_name, sd_weight in state_dict.items(): if TRTInterpreter.check_weight_equal(sd_weight, network_weight): @@ -460,6 +472,7 @@ def _save_weight_mapping(self) -> None: sd = {k: v.reshape(-1) for k, v in self.module.state_dict().items()} weight_name_map: dict[str, Any] = {} np_map = {} + constant_mapping = {} net = self.ctx.net for i in range(net.num_layers): layer = net[i] @@ -485,19 +498,22 @@ def _save_weight_mapping(self) -> None: suffix = sd_weight_name_list[-1] # Retrieve each weight name(s) in state_dict if layer_type == "CONSTANT": - if "embedding" in suffix: - sd_weight_name = f"{sd_weight_name}.weight" - elif "weight" in suffix or "mm_other" in suffix: - # Linear layer weight + if ( + "embedding" in suffix + or "weight" in suffix + or "mm_other" in suffix + ): sd_weight_name = f"{sd_weight_name}.weight" elif "running_mean" in suffix: - # Linear layer weight sd_weight_name = f"{sd_weight_name}.running_mean" elif "running_var" in suffix: - # Linear layer weight sd_weight_name = f"{sd_weight_name}.running_var" - else: + elif "bias" in suffix: sd_weight_name = f"{sd_weight_name}.bias" + else: + # Save the constant weights for future fast refit + sd_weight_name = f"{sd_weight_name}.unknown" + constant_mapping[engine_weight_name] = weight elif layer_type == "SCALE": # Batch norm needs all weights to calculate scale and shift sd_weight_name = [f"{sd_weight_name}.{n}" for n in torch_attr] @@ -518,18 
+534,126 @@ def _save_weight_mapping(self) -> None: weight_name_map[engine_weight_name] = TRTInterpreter.find_weight( engine_weight_name, np_map, sd ) + if ( + weight_name_map[engine_weight_name] != "" + and engine_weight_name in constant_mapping + ): + # If the weight is found in state_dict, remove it from constant_mapping + del constant_mapping[engine_weight_name] weight_name_map[engine_weight_name] = [ weight_name_map[engine_weight_name], np_map[engine_weight_name].dtype, ] + weight_name_map["constant_mapping"] = constant_mapping self.weight_name_map = weight_name_map del np_map, sd gc.collect() torch.cuda.empty_cache() + def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None: + # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine + # if not self.compilation_settings.strip_engine_weights: + # # set EXCLUDE_WEIGHTS flag to strip weights + # runtime = trt.Runtime(TRT_LOGGER) + # engine = runtime.deserialize_cuda_engine(serialized_engine) + + # serialization_config = engine.create_serialization_config() + # serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) + # serialized_engine = engine.serialize_with_config( + # serialization_config + # ) + + # Cache weighted engine for now + self.engine_cache.insert( # type: ignore[union-attr] + hash_val, + ( + serialized_engine, + self._input_names, + self._output_names, + self.input_specs, + self.compilation_settings, + self.weight_name_map, + ), + ) + + def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]: + # query the cached TRT engine + cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr] + if cached_data is not None: # hit the cache + ( + serialized_engine, + self._input_names, + self._output_names, + cached_engine_input_specs, + engine_compilation_settings, + self.weight_name_map, + ) = cached_data + + setting_compatiblity, incompattible_settings = settings_are_compatible( + self.compilation_settings, engine_compilation_settings + ) + assert ( + setting_compatiblity + ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})" + + for i, e in enumerate( + [ + Input.equivalent_spec(c, i) + for c, i in zip(cached_engine_input_specs, self.input_specs) + ] + ): + assert ( + e + ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}" + + _LOGGER.info( + "Found the cached engine that corresponds to this graph. It is directly loaded." 
+ ) + + # refit the cached engine with the new graph module + if not self.compilation_settings.strip_engine_weights: + runtime = trt.Runtime(TRT_LOGGER) + engine = runtime.deserialize_cuda_engine(serialized_engine) + + from torch_tensorrt.dynamo._refit import ( + _refit_single_trt_engine_with_gm, + ) + + _refit_single_trt_engine_with_gm( + new_gm=self.module, + old_engine=engine, + input_list=self.input_specs, + settings=self.compilation_settings, + weight_name_map=self.weight_name_map, + ) + serialized_engine = engine.serialize() + + # TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine + # # EXCLUDE_WEIGHTS flag must be cleared + # serialization_config = engine.create_serialization_config() + # serialization_config.clear_flag( + # trt.SerializationFlag.EXCLUDE_WEIGHTS + # ) + # serialized_engine = engine.serialize_with_config( + # serialization_config + # ) + # # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller + + with io.BytesIO() as engine_bytes: + engine_bytes.write(serialized_engine) + engine_str = engine_bytes.getvalue() + + return TRTInterpreterResult( + engine_str, + self._input_names, + self._output_names, + self.weight_name_map, + ) + return None + def run( self, strict_type_constraints: bool = False, @@ -548,7 +672,10 @@ def run( # self.engine_cache could be None if: # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or # 2) both cache_built_engines and reuse_cached_engines are False - if self.engine_cache is not None: + if ( + self.engine_cache is not None + and not self.compilation_settings.immutable_weights + ): if ( self.compilation_settings.cache_built_engines or self.compilation_settings.reuse_cached_engines @@ -557,75 +684,14 @@ def run( self.module, self.input_specs, self.compilation_settings ) - if self.compilation_settings.reuse_cached_engines: - # query the cached TRT engine - cached_data = self.engine_cache.check(hash_val) - if cached_data is not None: # hit the cache - ( - serialized_engine, - self._input_names, - self._output_names, - cached_engine_input_specs, - engine_compilation_settings, - self.weight_name_map, - ) = cached_data - - setting_compatiblity, incompattible_settings = ( - settings_are_compatible( - self.compilation_settings, engine_compilation_settings - ) - ) - assert ( - setting_compatiblity - ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})" - - for i, e in enumerate( - [ - Input.equivalent_spec(c, i) - for c, i in zip(cached_engine_input_specs, self.input_specs) - ] - ): - assert ( - e - ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}" - - _LOGGER.info( - "Found the cached engine that corresponds to this graph. It is directly loaded." - ) - - runtime = trt.Runtime(TRT_LOGGER) - engine = runtime.deserialize_cuda_engine(serialized_engine) - - from torch_tensorrt.dynamo._refit import ( - _refit_single_trt_engine_with_gm, - ) - - # TODO: Fast refit is problematic for now. It will fail if the engine has batch_norm layers. - # We set weight_name_map=None to use slow refit anyway for now. Will fix it in the future. 
- _refit_single_trt_engine_with_gm( - new_gm=self.module, - old_engine=engine, - input_list=self.input_specs, - settings=self.compilation_settings, - weight_name_map=None, - ) - - serialized_engine = engine.serialize() - - with io.BytesIO() as engine_bytes: - engine_bytes.write(serialized_engine) - engine_str = engine_bytes.getvalue() - - return TRTInterpreterResult( - engine_str, - self._input_names, - self._output_names, - self.weight_name_map, - ) + if self.compilation_settings.reuse_cached_engines: + interpreter_result = self._pull_cached_engine(hash_val) + if interpreter_result is not None: # hit the cache + return interpreter_result self._construct_trt_network_def() - if self.compilation_settings.make_refittable: + if not self.compilation_settings.immutable_weights: self._save_weight_mapping() build_engine_start_time = datetime.now() @@ -652,28 +718,24 @@ def run( self._save_timing_cache( builder_config, self.compilation_settings.timing_cache_path ) + + # Engine caching only for refittable engines if ( - self.engine_cache is not None + not self.compilation_settings.immutable_weights and self.compilation_settings.cache_built_engines + and self.engine_cache is not None ): - self.engine_cache.insert( - hash_val, - ( - serialized_engine, - self._input_names, - self._output_names, - self.input_specs, - self.compilation_settings, - self.weight_name_map, - ), - ) + self._insert_engine_to_cache(hash_val, serialized_engine) with io.BytesIO() as engine_bytes: engine_bytes.write(serialized_engine) engine_str = engine_bytes.getvalue() return TRTInterpreterResult( - engine_str, self._input_names, self._output_names, self.weight_name_map + engine_str, + self._input_names, + self._output_names, + self.weight_name_map, ) def run_node(self, n: torch.fx.Node) -> torch.fx.Node: diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index 884c51e8ea..4d2f97de1c 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -49,7 +49,9 @@ def get_ir(target: Target) -> SourceIR: return SourceIR.UNKNOWN -def one_user_validator(node: Node, settings: CompilationSettings = None) -> bool: +def one_user_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: # Validate only one user, which is a getitem node that accesses the first element in the list return ( len(node.users) == 1 @@ -131,7 +133,6 @@ def aten_ops_batch_norm_legit_no_training( @dynamo_tensorrt_converter( torch.ops.aten.native_layer_norm.default, - capability_validator=one_user_validator, supports_dynamic_shapes=True, ) @enforce_tensor_types( @@ -265,9 +266,11 @@ def aten_ops_embedding( ) -def embedding_bag_validator(node: Node, settings: CompilationSettings = None) -> bool: +def embedding_bag_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: # Embedding bag op is not refitable - if settings.make_refittable: + if settings and not settings.immutable_weights: return False if not one_user_validator(node): @@ -415,7 +418,9 @@ def aten_ops_symsize_int( return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1]) -def index_dtype_validator(node: Node, settings: CompilationSettings = None) -> bool: +def index_dtype_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: index = node.args[1] for ind in index: if ind is not None: @@ -841,7 +846,9 @@ def aten_ops_select( ) -def 
index_put_validator(node: Node, settings: CompilationSettings = None) -> bool: +def index_put_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: if args_bounds_check(node.args, 3, False): # Check if accumulate is valid _LOGGER.debug("We do not support accumulate=True for aten.index_put operation") accumulate_valid = False @@ -928,9 +935,9 @@ def aten_ops_slice( ) -def refit_validator(node: Node, settings: CompilationSettings = None) -> bool: +def refit_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool: # cumsum op is not refitable - if settings and settings.make_refittable: + if settings and not settings.immutable_weights: return False return True @@ -985,7 +992,9 @@ def aten_ops_tile( ) -def zero_output_validator(node: Node, settings: CompilationSettings = None) -> bool: +def zero_output_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: if 0 in node.args[1]: _LOGGER.debug( f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation." @@ -999,7 +1008,6 @@ def zero_output_validator(node: Node, settings: CompilationSettings = None) -> b torch.ops.aten.as_strided.default, capability_validator=zero_output_validator, ) -@dynamo_tensorrt_converter(torch.ops.aten.as_strided.default) def aten_ops_as_strided( ctx: ConversionContext, target: Target, @@ -1043,7 +1051,7 @@ def aten_ops_permute( def to_copy_dtype_validator( - placeholder_only: bool, settings: CompilationSettings = None + placeholder_only: bool, settings: Optional[CompilationSettings] = None ) -> Callable[[Node, CompilationSettings], bool]: """Return validator for to_copy node with placeholder restrictions""" @@ -1076,7 +1084,9 @@ def validate_dtype(to_copy_node: Node) -> bool: ) return False - def validator(to_copy_node: Node, settings: CompilationSettings = None) -> bool: + def validator( + to_copy_node: Node, settings: Optional[CompilationSettings] = None + ) -> bool: """Returns true if the to_copy node can be converted to TRT and the placeholder restriction is satisfied """ @@ -2045,7 +2055,6 @@ def aten_ops_div( @dynamo_tensorrt_converter( torch.ops.aten.pow.Tensor_Scalar, supports_dynamic_shapes=True ) -@dynamo_tensorrt_converter(operator.pow, supports_dynamic_shapes=True) def aten_ops_pow( ctx: ConversionContext, target: Target, @@ -2147,7 +2156,9 @@ def aten_ops_logical_xor( ) -def bitwise_type_validator(node: Node, settings: CompilationSettings = None) -> bool: +def bitwise_type_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: supported_type = [torch.bool, bool] tensor_targets = [ @@ -2291,7 +2302,7 @@ def aten_ops_bitwise_xor( def bitwise_not_type_validator( - node: Node, settings: CompilationSettings = None + node: Node, settings: Optional[CompilationSettings] = None ) -> bool: val = node.args[0] val_meta = val.meta.get("tensor_meta") @@ -2474,7 +2485,9 @@ def aten_ops_le( ) -def conv_param_validator(conv_node: Node, settings: CompilationSettings = None) -> bool: +def conv_param_validator( + conv_node: Node, settings: Optional[CompilationSettings] = None +) -> bool: return conv_node.args[7] in ([0], [0, 0], [0, 0, 0]) @@ -2571,7 +2584,7 @@ def aten_ops_cdist_forward( def avg_pool_param_validator( - pool_node: Node, settings: CompilationSettings = None + pool_node: Node, settings: Optional[CompilationSettings] = None ) -> bool: ceil_mode = args_bounds_check(pool_node.args, 4, False) divisor_override = args_bounds_check(pool_node.args, 6) @@ -2688,12 
+2701,12 @@ def aten_ops_adaptive_avg_poolNd( ) -def topk_validator(node: Node, settings: CompilationSettings = None) -> bool: +def topk_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool: k = node.args[1] return topk_sort_validator(k) -def sort_validator(node: Node, settings: CompilationSettings = None) -> bool: +def sort_validator(node: Node, settings: Optional[CompilationSettings] = None) -> bool: meta_data = node.args[0].meta.get("tensor_meta") if meta_data is None: return False @@ -2716,7 +2729,7 @@ def topk_sort_validator(k: int) -> bool: def max_pool_param_validator( - pool_node: Node, settings: CompilationSettings = None + pool_node: Node, settings: Optional[CompilationSettings] = None ) -> bool: dilation = args_bounds_check(pool_node.args, 4, 1) ceil_mode = args_bounds_check(pool_node.args, 5, False) @@ -2771,7 +2784,9 @@ def aten_ops_max_pool( ) -def attention_validator(node: Node, settings: CompilationSettings = None) -> bool: +def attention_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: # Currently, `attn_mask` is not supported return args_bounds_check(node.args, 3) is None @@ -3309,7 +3324,6 @@ def aten_ops_copy( @dynamo_tensorrt_converter( torch.ops.aten.remainder.Tensor, supports_dynamic_shapes=True ) -@dynamo_tensorrt_converter(operator.mod, supports_dynamic_shapes=True) @enforce_tensor_types( { 0: (TRTTensor,), @@ -3401,7 +3415,9 @@ def aten_ops_flip( ) -def zero_diag_size_validator(node: Node, settings: CompilationSettings = None) -> bool: +def zero_diag_size_validator( + node: Node, settings: Optional[CompilationSettings] = None +) -> bool: meta = node.args[0].meta.get("tensor_meta") if meta: input_shape = meta.shape @@ -3530,7 +3546,7 @@ def aten_ops_index_select( def dropout_inference_validator( - node: Node, settings: CompilationSettings = None + node: Node, settings: Optional[CompilationSettings] = None ) -> bool: train_mode = args_bounds_check(node.args, 2, None) if train_mode is False: diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index 28f0954185..134d84cf6d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -65,7 +65,7 @@ def __init__( Union[torch.dtype, dtype] ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, - make_refittable: bool = _defaults.MAKE_REFITTABLE, + immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS, debug: bool = _defaults.DEBUG, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, @@ -103,7 +103,7 @@ def __init__( assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels - refit (bool): Enable refitting + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. 
debug (bool): Enable debuggable engine capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels @@ -152,8 +152,8 @@ def __init__( device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} assert ( - make_refittable - ), "'make_refittable' has to be True for a MutableTorchTensorRTModule." + not immutable_weights + ), "`immutable_weights` has to be False for a MutableTorchTensorRTModule." compilation_options = { "enabled_precisions": ( enabled_precisions @@ -180,7 +180,7 @@ def __init__( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, - "make_refittable": make_refittable, + "immutable_weights": immutable_weights, "engine_capability": engine_capability, "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index e31d73f337..ffe7e9e03a 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -38,7 +38,7 @@ def __init__( *, name: str = "", settings: CompilationSettings = CompilationSettings(), - weight_name_map: Any = None, + weight_name_map: Optional[dict[Any, Any]] = None, ): """Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine @@ -51,6 +51,7 @@ def __init__( Keyword Arguments: name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed + weight_name_map (dict): Mapping of engine weight name to state_dict weight name Example: diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index 1bebe20fda..d7cfc6608b 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -96,6 +96,7 @@ def __init__( Keyword Arguments: name (str): Name for module settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed + weight_name_map (dict): Mapping of engine weight name to state_dict weight name Example: diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py index 95e5f30e4d..187b9472b1 100644 --- a/py/torch_tensorrt/dynamo/utils.py +++ b/py/torch_tensorrt/dynamo/utils.py @@ -558,10 +558,6 @@ def parse_dynamo_kwargs( engine_cache = None if kwargs.get("cache_built_engines") or kwargs.get("reuse_cached_engines"): - assert kwargs.get( - "make_refittable" - ), "Engine caching requires make_refittable to be set to True" - if kwargs.get("custom_engine_cache") is not None: engine_cache = kwargs.get("custom_engine_cache") else: diff --git a/tests/py/dynamo/conversion/harness.py b/tests/py/dynamo/conversion/harness.py index 61f891267e..26818acd8a 100644 --- a/tests/py/dynamo/conversion/harness.py +++ b/tests/py/dynamo/conversion/harness.py @@ -403,7 +403,7 @@ def run_test( enable_passes=False, propagate_shapes=False, int32_reqd=False, - make_refittable=False, + 
immutable_weights=True, ): # TODO: lan to remove this and set use_dynamo_traccer to True by default # once all the converter test files are moved to use_dynamo_tracer @@ -414,7 +414,7 @@ def run_test( enabled_precisions={dtype._from(precision)}, truncate_double=True, debug=True, - make_refittable=make_refittable, + immutable_weights=immutable_weights, ) mod = self.generate_graph( @@ -498,7 +498,7 @@ def run_test_compare_tensor_attributes_only( output_dtypes=None, use_dynamo_tracer=False, enable_passes=False, - make_refittable=False, + immutable_weights=True, ): # Previous instance of the interpreter auto-casted 64-bit inputs @@ -507,7 +507,7 @@ def run_test_compare_tensor_attributes_only( enabled_precisions={dtype._from(precision)}, truncate_double=True, debug=True, - make_refittable=make_refittable, + immutable_weights=immutable_weights, ) mod = self.generate_graph( @@ -541,7 +541,7 @@ def run_test_with_dynamic_shape( pyt_inputs=None, propagate_shapes=False, check_dtype=True, - make_refittable=False, + immutable_weights=True, torch_export_dynamic_shapes=None, ): # TODO: lan to remove this and set use_dynamo_traccer to True by default @@ -551,7 +551,8 @@ def run_test_with_dynamic_shape( # Previous instance of the interpreter auto-casted 64-bit inputs # We replicate this behavior here compilation_settings = CompilationSettings( - truncate_double=True, make_refittable=make_refittable + truncate_double=True, + immutable_weights=immutable_weights, ) mod = self.generate_graph( mod, diff --git a/tests/py/dynamo/conversion/test_chunk_aten.py b/tests/py/dynamo/conversion/test_chunk_aten.py deleted file mode 100644 index eb06c04201..0000000000 --- a/tests/py/dynamo/conversion/test_chunk_aten.py +++ /dev/null @@ -1,187 +0,0 @@ -import unittest - -import torch -from parameterized import parameterized -from torch.testing._internal.common_utils import run_tests -from torch_tensorrt import Input - -from .harness import DispatchTestCase - - -class TestChunkConverter(DispatchTestCase): - @parameterized.expand( - [ - ((1,), 3, 0), - ((3,), 3, 0), - ((4,), 3, 0), - ((6,), 3, 0), - ((3,), 1, -1), - ((3,), 3, -1), - ((3,), 4, -1), - ] - ) - def test_chunk_1D(self, shape, chunks, dim): - class TestChunk(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.chunk.default(input, chunks, dim) - return out - - input = [torch.randn(shape)] - self.run_test( - TestChunk(), - input, - use_dynamo_tracer=True, - ) - - @parameterized.expand( - [ - ((3, 4), 1, 0), - ((3, 4), 3, 0), - ((3, 4), 4, 0), - ((3, 4), 2, -2), - ((3, 4), 6, -2), - ((3, 4), 3, 1), - ((3, 4), 4, 1), - ((3, 4), 5, -1), - ] - ) - def test_chunk_2D(self, shape, chunks, dim): - class TestChunk(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.chunk.default(input, chunks, dim) - return out - - input = [torch.randn(shape)] - self.run_test( - TestChunk(), - input, - use_dynamo_tracer=True, - ) - - @parameterized.expand( - [ - ((3, 4, 2), 1, 0), - ((3, 4, 2), 3, -3), - ((3, 4, 2), 3, 1), - ((3, 4, 2), 4, 1), - ((3, 4, 2), 6, -2), - ((3, 4, 2), 1, 2), - ((3, 4, 2), 3, -1), - ((3, 4, 2), 4, -1), - ] - ) - def test_chunk_3D(self, shape, chunks, dim): - class TestChunk(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.chunk.default(input, chunks, dim) - return out - - input = [torch.randn(shape)] - self.run_test( - TestChunk(), - input, - use_dynamo_tracer=True, - ) - - -#######################Dynamic cases####################### -# The tests are skipped for now. 
Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed -@unittest.skip( - "Pending aten.split dynamic input torch.export guard bug. Issue- https://github.com/pytorch/pytorch/issues/134663" -) -class TestChunkDynamicConverter(DispatchTestCase): - @parameterized.expand( - [ - ((1,), (1,), (3,), 3, 0), - ((3,), (3,), (4,), 3, 0), - ((4,), (4,), (6,), 3, 0), - ((6,), (6,), (9,), 3, 0), - ((3,), (3,), (4,), 1, -1), - ((3,), (3,), (4,), 3, -1), - ((3,), (3,), (4,), 4, -1), - ] - ) - def test_chunk_1D(self, min_shape, opt_shape, max_shape, chunks, dim): - class TestChunk(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.chunk.default(input, chunks, dim) - return out - - input_specs = [ - Input( - min_shape=min_shape, - opt_shape=opt_shape, - max_shape=max_shape, - ), - ] - self.run_test_with_dynamic_shape( - TestChunk(), - input_specs, - use_dynamo_tracer=True, - ) - - @parameterized.expand( - [ - ((3, 4), (3, 4), (4, 4), 1, 0), - ((3, 4), (3, 4), (4, 4), 3, 0), - ((3, 4), (3, 4), (4, 4), 4, 0), - ((3, 4), (3, 4), (4, 4), 2, -2), - ((3, 4), (3, 4), (4, 4), 6, -2), - ((3, 4), (3, 4), (4, 4), 3, 1), - ((3, 4), (3, 4), (4, 4), 4, 1), - ((3, 4), (3, 4), (4, 4), 5, -1), - ] - ) - def test_chunk_2D(self, min_shape, opt_shape, max_shape, chunks, dim): - class TestChunk(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.chunk.default(input, chunks, dim) - return out - - input_specs = [ - Input( - min_shape=min_shape, - opt_shape=opt_shape, - max_shape=max_shape, - ), - ] - self.run_test_with_dynamic_shape( - TestChunk(), - input_specs, - use_dynamo_tracer=True, - ) - - @parameterized.expand( - [ - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 0), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -3), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, 1), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, 1), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 6, -2), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 2), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -1), - ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, -1), - ] - ) - def test_chunk_3D(self, min_shape, opt_shape, max_shape, chunks, dim): - class TestChunk(torch.nn.Module): - def forward(self, input): - out = torch.ops.aten.chunk.default(input, chunks, dim) - return out - - input_specs = [ - Input( - min_shape=min_shape, - opt_shape=opt_shape, - max_shape=max_shape, - ), - ] - self.run_test_with_dynamic_shape( - TestChunk(), - input_specs, - use_dynamo_tracer=True, - ) - - -if __name__ == "__main__": - run_tests() diff --git a/tests/py/dynamo/conversion/test_cumsum_aten.py b/tests/py/dynamo/conversion/test_cumsum_aten.py index 1c32be6dd6..8ab699468d 100644 --- a/tests/py/dynamo/conversion/test_cumsum_aten.py +++ b/tests/py/dynamo/conversion/test_cumsum_aten.py @@ -24,7 +24,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, - make_refittable=False, + immutable_weights=True, ) @parameterized.expand( @@ -44,7 +44,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, - make_refittable=False, + immutable_weights=True, ) @parameterized.expand( @@ -65,7 +65,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, - make_refittable=False, + immutable_weights=True, ) @parameterized.expand( @@ -95,7 +95,7 @@ def forward(self, x): self.run_test_with_dynamic_shape( Cumsum(), inputs, - make_refittable=False, + immutable_weights=True, ) diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 6543ac2306..1f119bd77e 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ 
b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -148,7 +148,7 @@ def forward(self, weight, indices): precision=weight.dtype, enable_passes=True, propagate_shapes=True, - make_refittable=False, + immutable_weights=True, ) @parameterized.expand( @@ -346,7 +346,7 @@ def forward(self, weight, indices, offsets): precision=weight.dtype, enable_passes=True, propagate_shapes=True, - make_refittable=False, + immutable_weights=True, ) @parameterized.expand( @@ -411,7 +411,7 @@ def forward(self, weight, indices, offsets): precision=weight.dtype, enable_passes=True, propagate_shapes=True, - make_refittable=False, + immutable_weights=True, ) @parameterized.expand( @@ -493,7 +493,7 @@ def forward(self, weights, indices, offsets, per_sample_weights=None): min_block_size=1, cache_built_engines=False, reuse_cached_engines=False, - make_refittable=False, + immutable_weights=True, ) # use the inputs with different shape to inference: if per_sample_weights is None: diff --git a/tests/py/dynamo/conversion/test_group_norm_aten.py b/tests/py/dynamo/conversion/test_group_norm_aten.py index 617166d0c4..b62be920f9 100644 --- a/tests/py/dynamo/conversion/test_group_norm_aten.py +++ b/tests/py/dynamo/conversion/test_group_norm_aten.py @@ -135,10 +135,10 @@ def forward(self, x): @parameterized.expand( [ - (5, 4, 4, 2, (2, 4, 2), (3, 4, 2), (5, 4, 4)), - (5, 4, 2 * 2, 2, (2, 4, 2, 2), (3, 4, 2, 2), (5, 4, 2, 2)), - (5, 9, 6 * 3, 3, (3, 9, 3, 3), (4, 9, 3, 3), (5, 9, 6, 3)), - (8, 9, 6 * 6, 3, (3, 9, 2, 3, 2), (5, 9, 3, 3, 2), (8, 9, 6, 3, 2)), + (5, 4, 4, 2, (2, 4, 2), (5, 4, 4), (5, 4, 4)), + (5, 4, 2 * 2, 2, (2, 4, 2, 2), (5, 4, 2, 2), (5, 4, 2, 2)), + (5, 9, 6 * 3, 3, (3, 9, 3, 3), (5, 9, 6, 3), (5, 9, 6, 3)), + (8, 9, 6 * 6, 3, (3, 9, 2, 3, 2), (8, 9, 6, 3, 2), (8, 9, 6, 3, 2)), ] ) def test_groupnorm_with_dynamic_shape( diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index 5ceea5e381..68451674c5 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -74,7 +74,7 @@ def test_reexport_is_equal(self): ), ) settings1 = CompilationSettings( - make_refittable=True, cache_built_engines=True, reuse_cached_engines=True + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True ) hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1) @@ -89,7 +89,7 @@ def test_reexport_is_equal(self): ), ) settings2 = CompilationSettings( - make_refittable=True, cache_built_engines=True, reuse_cached_engines=True + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True ) hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2) @@ -111,7 +111,7 @@ def test_input_shape_change_is_not_equal(self): ), ) settings1 = CompilationSettings( - make_refittable=True, cache_built_engines=True, reuse_cached_engines=True + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True ) hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1) @@ -126,7 +126,7 @@ def test_input_shape_change_is_not_equal(self): ), ) settings2 = CompilationSettings( - make_refittable=True, cache_built_engines=True, reuse_cached_engines=True + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True ) hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2) @@ -148,7 +148,7 @@ def test_engine_settings_is_not_equal(self): ), ) settings1 = CompilationSettings( - make_refittable=True, + 
immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True, enabled_precisions={torch.float32}, @@ -166,7 +166,7 @@ def test_engine_settings_is_not_equal(self): ), ) settings2 = CompilationSettings( - make_refittable=True, + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True, enabled_precisions={torch.float32, torch.float16}, @@ -206,6 +206,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) for i in range(3): + # remove timing cache and reset dynamo for engine caching measurement remove_timing_cache() torch._dynamo.reset() if i == 0: @@ -220,11 +221,11 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): trt_gm = torch_trt.dynamo.compile( exp_program, tuple(inputs), - use_python_runtime=False, + use_python_runtime=True, enabled_precisions={torch.float}, debug=False, min_block_size=1, - make_refittable=True, + immutable_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, engine_cache_dir=engine_cache_dir, @@ -284,11 +285,11 @@ def test_dynamo_compile_with_custom_engine_cache(self): trt_gm = torch_trt.dynamo.compile( exp_program, tuple(inputs), - use_python_runtime=False, + use_python_runtime=True, enabled_precisions={torch.float}, debug=False, min_block_size=1, - make_refittable=True, + immutable_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, custom_engine_cache=custom_engine_cache, @@ -335,7 +336,7 @@ def test_dynamo_compile_change_input_shape(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, - make_refittable=True, + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True, ) @@ -386,11 +387,11 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): model, backend="tensorrt", options={ - "use_python_runtime": True, + "use_python_runtime": False, "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, - "make_refittable": True, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": engine_cache_dir, @@ -400,7 +401,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): results.append(compiled_model(*inputs)) # trigger the compilation end.record() torch.cuda.synchronize() - torch._dynamo.reset() times.append(start.elapsed_time(end)) cos_sim = cosine_similarity(results[0], results[1]) @@ -451,11 +451,11 @@ def test_torch_compile_with_custom_engine_cache(self): model, backend="tensorrt", options={ - "use_python_runtime": True, + "use_python_runtime": False, "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, - "make_refittable": True, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "custom_engine_cache": custom_engine_cache, @@ -487,18 +487,59 @@ def test_torch_compile_with_custom_engine_cache(self): for h, count in custom_engine_cache.hashes.items() ] - def test_torch_compile_change_input_shape(self): + def test_torch_trt_compile_change_input_shape(self): # Custom Engine Cache model = models.resnet18(pretrained=True).eval().to("cuda") - - engine_cache_dir = "/tmp/test_torch_compile_with_default_disk_engine_cache" + engine_cache_dir = "/tmp/test_torch_trt_compile_change_input_shape" if os.path.exists(engine_cache_dir): shutil.rmtree(engine_cache_dir) custom_engine_cache = MyEngineCache(engine_cache_dir) for i in range(3): - # remove
timing cache and reset dynamo for engine caching messurement inputs = [torch.rand((4 * (i + 1), 3, 224, 224)).to("cuda")] + compiled_model = torch_trt.compile( + model, + inputs=inputs, + **{ + "use_python_runtime": True, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "immutable_weights": False, + "cache_built_engines": True, + "reuse_cached_engines": True, + "custom_engine_cache": custom_engine_cache, + }, + ) + compiled_model(*inputs) + [ + assertions.assertTrue( + count == 0, f"Unintended cache hit for entry ({h}, hit: {count})" + ) + for h, count in custom_engine_cache.hashes.items() + ] + + def test_torch_compile_graph_break(self): + class MyModel(torch.nn.Module): + def forward(self, x): + x = x + x + x = x + x + x = torch.ops.aten.relu.default(x) + x = x + x + x = x + x + x = torch.ops.aten.relu.default(x) + x = x + x + x = x + x + return x + + model = MyModel().eval().cuda() + engine_cache_dir = "/tmp/test_torch_compile_graph_break" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + custom_engine_cache = MyEngineCache(engine_cache_dir) + inputs = [torch.rand((3, 3, 224, 224)).to("cuda")] + for i in range(3): compiled_model = torch.compile( model, backend="tensorrt", @@ -507,17 +548,460 @@ def test_torch_compile_change_input_shape(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, - "make_refittable": True, + "immutable_weights": False, "cache_built_engines": True, "reuse_cached_engines": True, "custom_engine_cache": custom_engine_cache, "torch_executed_ops": {"torch.ops.aten.relu.default"}, }, ) + compiled_model(*inputs) [ assertions.assertTrue( - count == 0, f"Unintended cache hit for entry ({h}, hit: {count})" + count == 2, + f"cache was not hit exactly twice for entry ({h}, hit: {count})", ) for h, count in custom_engine_cache.hashes.items() ] + + def test_isomorphic_graphs(self): + class MyModel1(torch.nn.Module): + def forward(self, a, b): + return a + b + + class MyModel2(torch.nn.Module): + def forward(self, c, d): + return c + d + + model1 = MyModel1().eval().cuda() + model2 = MyModel2().eval().cuda() + + inputs1 = (torch.randn((2, 3)).to("cuda"), torch.randn((2, 3)).to("cuda")) + inputs2 = (torch.randn((2, 3)).to("cuda"), torch.randn((2, 3)).to("cuda")) + + exp_program1 = torch.export.export(model1, args=inputs1) + exp_program2 = torch.export.export(model2, args=inputs2) + + input_specs1 = ( + torch_trt.Input( + min_shape=(1, 3), + opt_shape=(2, 3), + max_shape=(10, 3), + ), + ) + + input_specs2 = ( + torch_trt.Input( + min_shape=(1, 3), + opt_shape=(2, 3), + max_shape=(10, 3), + ), + ) + + settings1 = CompilationSettings( + cache_built_engines=True, reuse_cached_engines=True + ) + + settings2 = CompilationSettings( + cache_built_engines=True, reuse_cached_engines=True + ) + + hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1) + hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2) + + assertions.assertEqual(hash1, hash2) + + # @unittest.skip("benchmark on small models") + def test_caching_small_model(self): + from torch_tensorrt.dynamo._refit import refit_module_weights + + model = models.resnet18(pretrained=True).eval().to("cuda") + + engine_cache_dir = "/tmp/test_caching_small_model" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + def remove_timing_cache(path=TIMING_CACHE_PATH): + if os.path.exists(path): + os.remove(path) + + inputs = (torch.rand((100, 3, 224, 224)).to("cuda"),) + exp_program = 
torch.export.export(model, args=inputs) + + # warm up + trt_gm = torch_trt.dynamo.compile( + exp_program, + inputs, + use_python_runtime=True, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + immutable_weights=False, + cache_built_engines=False, + reuse_cached_engines=False, + strip_engine_weights=False, + refit_identical_engine_weights=False, + ) + torch.cuda.empty_cache() + + compile_times = [[] for _ in range(3)] + inference_times = [[] for _ in range(3)] + results = [[] for _ in range(3)] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + interval = 3 + for i in range(interval * 3): + if i < interval: + # non-refittable + immutable_weights = True + strip_engine_weights = False + refit_identical_engine_weights = False + cache_built_engines = reuse_cached_engines = False + # continue + elif i < interval * 2: + # REFIT w/ engine caching + immutable_weights = False + strip_engine_weights = False + refit_identical_engine_weights = False + cache_built_engines = reuse_cached_engines = True + # continue + else: + # REFIT_IDENTICAL w/ engine caching + immutable_weights = False + strip_engine_weights = False + refit_identical_engine_weights = True + cache_built_engines = reuse_cached_engines = True + # continue + + if i % interval == 0: + remove_timing_cache() + + torch._dynamo.reset() + + torch.cuda.synchronize() + start.record() + + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + cache_built_engines=cache_built_engines, + reuse_cached_engines=reuse_cached_engines, + engine_cache_dir=engine_cache_dir, + engine_cache_size=1 << 40, + immutable_weights=immutable_weights, + strip_engine_weights=strip_engine_weights, + refit_identical_engine_weights=refit_identical_engine_weights, + ) + + if strip_engine_weights: + trt_gm = refit_module_weights(trt_gm, exp_program) + + end.record() + torch.cuda.synchronize() + compile_times[i // interval].append(start.elapsed_time(end)) + + # inference + torch.cuda.synchronize() + start.record() + out = trt_gm(*inputs) + end.record() + torch.cuda.synchronize() + inference_times[i // interval].append(start.elapsed_time(end)) + + results[i // interval].append(out) + + torch.cuda.empty_cache() + + cos_sim = cosine_similarity(torch.stack(results[0]), torch.stack(results[1])) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(torch.stack(results[1]), torch.stack(results[2])) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + assertions.assertTrue( + compile_times[1][0] > compile_times[1][1], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[1][0]} ms, time taken with engine caching: {compile_times[1][1]} ms", + ) + + assertions.assertTrue( + compile_times[1][0] > compile_times[1][2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[1][0]} ms, time taken with engine caching: {compile_times[1][2]} ms", + ) + + assertions.assertTrue( + compile_times[2][0] > compile_times[2][1], + msg=f"Engine caching didn't speed up the compilation. 
Time taken without engine caching: {compile_times[2][0]} ms, time taken with engine caching: {compile_times[2][1]} ms", + ) + + assertions.assertTrue( + compile_times[2][0] > compile_times[2][2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[2][0]} ms, time taken with engine caching: {compile_times[2][2]} ms", + ) + + assertions.assertTrue( + compile_times[0][2] > compile_times[1][2], + msg=f"Engine caching is slower than recompiling a non-refittable engine. Recompile a non-refittable engine: {compile_times[0][2]} ms, time taken with engine caching: {compile_times[1][2]} ms", + ) + + assertions.assertTrue( + compile_times[0][2] > compile_times[2][2], + msg=f"Engine caching is slower than recompiling a non-refittable engine. Recompile a non-refittable engine: {compile_times[0][2]} ms, time taken with engine caching: {compile_times[2][2]} ms", + ) + + @unittest.skip("benchmark on llama2") + def test_caching_llama2_model(self): + import torch + from torch_tensorrt.dynamo._refit import refit_module_weights + from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + StoppingCriteriaList, + ) + from transformers.generation.stopping_criteria import ( + EosTokenCriteria, + MaxLengthCriteria, + ) + + def export_llm(model, inputs, min_seq_len=1, max_seq_len=16): + """ + Exports the LLM model into an ExportedProgram with dynamic shapes. + In the case of guard failures due to some PyTorch kernel implementations, we also + try to re-export the graph by expressing them as runtime assert nodes + """ + with torch.no_grad(): + # max=1024 has constraint violation error. https://github.com/pytorch/pytorch/issues/125604 + seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len) + try: + print("Trying to export the model using torch.export.export()..") + # strict=False only enables aotautograd tracing and excludes dynamo. + ep = torch.export.export( + model, (inputs,), dynamic_shapes=({1: seq_len},), strict=False + ) + except: + print( + "Trying torch.export._trace._export to trace the graph since torch.export.export() failed" + ) + # This API is used to express the constraint violation guards as asserts in the graph. + ep = torch.export._trace._export( + model, + (inputs,), + dynamic_shapes=({1: seq_len},), + strict=False, + allow_complex_guards_as_runtime_asserts=True, + ) + + return ep + + def generate(model, input_seq, max_tokens, eos_token_id): + """ + Greedy decoding of the model. This generates up to max_tokens. + """ + # Max length of output seq = current input_seq length + max_tokens allowed to generate + max_output_seq_length = input_seq.shape[1] + max_tokens + stopping_criteria = StoppingCriteriaList( + [ + MaxLengthCriteria(max_length=max_output_seq_length), + EosTokenCriteria(eos_token_id=eos_token_id), + ] + ) + + while True: + outputs = model(input_seq) + logits = outputs.logits + next_token_logits = logits[:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1) + input_seq = torch.cat([input_seq, next_tokens[:, None]], dim=-1) + # TODO: Handle batch in this check + if stopping_criteria(input_seq, logits).item(): + break + + return input_seq + + MAX_TOKENS = 32 + DEVICE = torch.device("cuda:0") + + llama_path = "meta-llama/Llama-2-7b-chat-hf" + with torch.no_grad(): + model = AutoModelForCausalLM.from_pretrained( + llama_path, use_cache=False, attn_implementation="eager" + ).eval() + + tokenizer = AutoTokenizer.from_pretrained(llama_path) + + prompt = "What is dynamic programming?"
+ model_inputs = tokenizer(prompt, return_tensors="pt") + input_ids = model_inputs.input_ids + + llama2_ep = export_llm(model, input_ids, max_seq_len=64) + + engine_cache_dir = "/tmp/test_caching_llama2_model" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + timing_cache_path = os.path.join( + engine_cache_dir, "llama2_timing_cache_original.bin" + ) + + def remove_timing_cache(path=timing_cache_path): + if os.path.exists(path): + os.remove(path) + + input_ids = input_ids.to(DEVICE) + + # warm up + trt_gm = torch_trt.dynamo.compile( + llama2_ep, + inputs=[input_ids], + use_python_runtime=True, + enabled_precisions={torch.float32}, + debug=False, + min_block_size=1, + immutable_weights=False, + truncate_double=True, + device=DEVICE, + disable_tf32=True, + cache_built_engines=False, + reuse_cached_engines=False, + strip_engine_weights=False, + refit_identical_engine_weights=False, + timing_cache_path=timing_cache_path, + ) + torch.cuda.empty_cache() + + compile_times = [[] for _ in range(3)] + inference_times = [[] for _ in range(3)] + results = [[] for _ in range(3)] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + interval = 3 + for i in range(interval * 3): + if i < interval: + # non-refittable + immutable_weights = True + strip_engine_weights = False + refit_identical_engine_weights = False + cache_built_engines = reuse_cached_engines = False + elif i < interval * 2: + # REFIT w/ engine caching + immutable_weights = False + strip_engine_weights = False + refit_identical_engine_weights = False + cache_built_engines = reuse_cached_engines = True + else: + # REFIT_IDENTICAL w/ engine caching + immutable_weights = False + strip_engine_weights = False + refit_identical_engine_weights = True + cache_built_engines = reuse_cached_engines = True + + if i % interval == 0: + remove_timing_cache() + + torch._dynamo.reset() + + torch.cuda.synchronize() + start.record() + + trt_gm = torch_trt.dynamo.compile( + llama2_ep, + inputs=[input_ids], + use_python_runtime=True, + enabled_precisions={torch.float32}, + debug=False, + min_block_size=1, + truncate_double=True, + device=DEVICE, + disable_tf32=True, + cache_built_engines=cache_built_engines, + reuse_cached_engines=reuse_cached_engines, + engine_cache_dir=engine_cache_dir, + engine_cache_size=1 << 40, + immutable_weights=immutable_weights, + strip_engine_weights=strip_engine_weights, + refit_identical_engine_weights=refit_identical_engine_weights, + timing_cache_path=timing_cache_path, + ) + + if strip_engine_weights: + trt_gm = refit_module_weights(trt_gm, llama2_ep) + + end.record() + torch.cuda.synchronize() + + compile_times[i // interval].append(start.elapsed_time(end)) + + # inference + torch.cuda.synchronize() + start.record() + + trt_gen_tokens = generate( + trt_gm, input_ids, MAX_TOKENS, tokenizer.eos_token_id + ) + # trt_gen_text = tokenizer.batch_decode( + # trt_gen_tokens, + # skip_special_tokens=True, + # clean_up_tokenization_spaces=False, + # )[0], + results[i // interval].append(trt_gen_tokens) + + end.record() + torch.cuda.synchronize() + + inference_times[i // interval].append(start.elapsed_time(end)) + + torch.cuda.empty_cache() + + cos_sim = cosine_similarity(torch.stack(results[0]), torch.stack(results[1])) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(torch.stack(results[1]), torch.stack(results[2])) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + assertions.assertTrue( + compile_times[1][0] > compile_times[1][1], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[1][0]} ms, time taken with engine caching: {compile_times[1][1]} ms", + ) + + assertions.assertTrue( + compile_times[1][0] > compile_times[1][2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[1][0]} ms, time taken with engine caching: {compile_times[1][2]} ms", + ) + + assertions.assertTrue( + compile_times[2][0] > compile_times[2][1], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[2][0]} ms, time taken with engine caching: {compile_times[2][1]} ms", + ) + + assertions.assertTrue( + compile_times[2][0] > compile_times[2][2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {compile_times[2][0]} ms, time taken with engine caching: {compile_times[2][2]} ms", + ) + + assertions.assertTrue( + compile_times[0][2] > compile_times[1][2], + msg=f"Engine caching is slower than recompiling a non-refittable engine. Recompile a non-refittable engine: {compile_times[0][2]} ms, time taken with engine caching: {compile_times[1][2]} ms", + ) + + assertions.assertTrue( + compile_times[0][2] > compile_times[2][2], + msg=f"Engine caching is slower than recompiling a non-refittable engine. Recompile a non-refittable engine: {compile_times[0][2]} ms, time taken with engine caching: {compile_times[2][2]} ms", + ) diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 07a9353037..bb61ac2d43 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -1,9 +1,7 @@ import os import tempfile -import time import unittest -import numpy as np import pytest import tensorrt as trt import torch @@ -57,8 +55,7 @@ def test_mapping(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, + immutable_weights=False, ) settings = trt_gm._run_on_acc_0.settings runtime = trt.Runtime(TRT_LOGGER) @@ -110,8 +107,7 @@ def test_refit_one_engine_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -161,8 +157,7 @@ def test_refit_one_engine_no_map_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, + immutable_weights=False, ) trt_gm._run_on_acc_0.weight_name_map = None @@ -213,8 +208,7 @@ def test_refit_one_engine_with_wrong_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, + immutable_weights=False, ) # Manually Deleted all batch norm layer. 
This suppose to fail the fast refit trt_gm._run_on_acc_0.weight_name_map = { @@ -271,8 +265,7 @@ def test_refit_one_engine_bert_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -325,8 +318,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, + immutable_weights=False, ) torchtrt.save(trt_gm, trt_ep_path) trt_gm = torch.export.load(trt_ep_path) @@ -372,8 +364,7 @@ def test_refit_one_engine_python_runtime_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -443,7 +434,7 @@ def forward(self, x): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, + immutable_weights=False, torch_executed_ops=torch_executed_ops, reuse_cached_engines=False, ) @@ -494,8 +485,7 @@ def test_refit_one_engine_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -546,8 +536,7 @@ def test_refit_one_engine_bert_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -600,8 +589,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, + immutable_weights=False, ) torchtrt.save(trt_gm, trt_ep_path) trt_gm = torch.export.load(trt_ep_path) @@ -647,8 +635,7 @@ def test_refit_one_engine_python_runtime_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, - reuse_cached_engines=False, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -718,7 +705,7 @@ def forward(self, x): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, - make_refittable=True, + immutable_weights=False, torch_executed_ops=torch_executed_ops, reuse_cached_engines=False, ) @@ -772,7 +759,7 @@ def forward(self, x): enabled_precisions={torch.float}, debug=True, min_block_size=1, - make_refittable=True, + immutable_weights=False, ) num_pyt_segments = len( diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py new file mode 100644 index 0000000000..0c79ba7a3f --- /dev/null +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -0,0 +1,564 @@ +import os +import pickle +import shutil +import unittest + +import torch +import torch_tensorrt as torch_trt +import torchvision.models as models +from torch.testing._internal.common_utils import TestCase +from torch_tensorrt.dynamo import convert_exported_program_to_serialized_trt_engine +from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH +from torch_tensorrt.dynamo._refit import refit_module_weights +from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity + +assertions = 
unittest.TestCase() + + +class TestWeightStrippedEngine(TestCase): + def test_three_ways_to_compile(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + exp_program = torch.export.export(pyt_model, example_inputs) + + settings = { + "use_python_runtime": False, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "immutable_weights": False, + "strip_engine_weights": False, + "refit_identical_engine_weights": False, + } + + # 1. Compile with torch_trt.dynamo.compile + gm1 = torch_trt.dynamo.compile( + exp_program, + example_inputs, + **settings, + ) + gm1_output = gm1(*example_inputs) + + # 2. Compile with torch.compile using tensorrt backend + gm2 = torch.compile( + pyt_model, + backend="tensorrt", + options=settings, + ) + gm2_output = gm2(*example_inputs) + + pyt_model_output = pyt_model(*example_inputs) + + assert torch.allclose( + pyt_model_output, gm1_output, 1e-2, 1e-2 + ), "gm1_output is not correct" + + assert torch.allclose( + gm1_output, gm2_output, 1e-2, 1e-2 + ), "gm2_output is not correct" + + def test_three_ways_to_compile_weight_stripped_engine(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + + settings = { + "use_python_runtime": False, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "immutable_weights": False, + "strip_engine_weights": True, + "refit_identical_engine_weights": False, + } + + # 1. Compile with torch_trt.compile using dynamo backend + gm1 = torch_trt.compile( + pyt_model, ir="dynamo", inputs=example_inputs, **settings + ) + gm1_output = gm1(*example_inputs) + + # 2. Compile with torch.compile using tensorrt backend, which is not supported to set strip_engine_weights=True + # gm2 = torch.compile( + # pyt_model, + # backend="tensorrt", + # options=settings, + # ) + # gm2_output = gm2(*example_inputs) + + assertions.assertEqual( + gm1_output.sum(), 0, msg="gm1_output should be all zeros" + ) + + def test_weight_stripped_engine_sizes(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + exp_program = torch.export.export(pyt_model, example_inputs) + weight_included_engine = convert_exported_program_to_serialized_trt_engine( + exp_program, + example_inputs, + immutable_weights=False, + strip_engine_weights=False, + refit_identical_engine_weights=False, + ) + weight_stripped_engine = convert_exported_program_to_serialized_trt_engine( + exp_program, + example_inputs, + immutable_weights=False, + strip_engine_weights=True, + refit_identical_engine_weights=False, + ) + weight_stripped_refit_identical_engine = ( + convert_exported_program_to_serialized_trt_engine( + exp_program, + example_inputs, + immutable_weights=False, + strip_engine_weights=True, + refit_identical_engine_weights=True, + ) + ) + assertions.assertTrue( + len(bytes(weight_included_engine)) > len(bytes(weight_stripped_engine)), + msg=f"Weight-stripped engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, weight-stripped engine size: {len(bytes(weight_stripped_engine))}", + ) + assertions.assertTrue( + len(bytes(weight_included_engine)) + > len(bytes(weight_stripped_refit_identical_engine)), + msg=f"Weight-stripped refit-identical engine size is not smaller than the weight included engine size. 
Weight included engine size: {len(bytes(weight_included_engine))}, weight-stripped refit-identical engine size: {len(bytes(weight_stripped_refit_identical_engine))}", + ) + + def test_weight_stripped_engine_results(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + # Mark the dim0 of inputs as dynamic + batch = torch.export.Dim("batch", min=1, max=200) + exp_program = torch.export.export( + pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}} + ) + + inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + immutable_weights=False, + strip_engine_weights=True, + refit_identical_engine_weights=False, + ) + output = trt_gm(*inputs) + assertions.assertEqual( + output.sum(), 0, msg="weight-stripped engine results should be all zeros" + ) + + # Refit the weight-stripped engine with the same weights + refitted_trt_gm = refit_module_weights(trt_gm, exp_program) + refitted_output = refitted_trt_gm(*inputs) + assertions.assertNotEqual( + refitted_output.sum(), + 0, + msg="refitted engine results should not be all zeros", + ) + + compiled_model = torch.compile( + pyt_model, + backend="tensorrt", + options={ + "use_python_runtime": False, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "immutable_weights": False, + "cache_built_engines": False, + "reuse_cached_engines": False, + "refit_identical_engine_weights": False, + "strip_engine_weights": False, + }, + ) + compiled_model_output = compiled_model(*inputs) + cos_sim = cosine_similarity(refitted_output, compiled_model_output) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"refitted_output doesn't match with compiled_model_output. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + @unittest.skip( + "For now, torch-trt will save weighted engine if strip_engine_weights is False. 
In the near future, we plan to save weight-stripped engine regardless of strip_engine_weights, which is pending on TRT's feature development: NVBug #4914602" + ) + def test_engine_caching_saves_weight_stripped_engine(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + exp_program = torch.export.export(pyt_model, example_inputs) + + engine_cache_dir = "/tmp/test_engine_caching_saves_weight_stripped_engine" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + weight_included_engine = convert_exported_program_to_serialized_trt_engine( + exp_program, + example_inputs, + strip_engine_weights=False, + refit_identical_engine_weights=False, + ) + + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(example_inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + immutable_weights=False, + strip_engine_weights=False, + refit_identical_engine_weights=True, + cache_built_engines=True, + reuse_cached_engines=True, + engine_cache_dir=engine_cache_dir, + ) + output = trt_gm(*example_inputs) + assertions.assertNotEqual(output.sum(), 0, msg="results shouldn't be all zeros") + + blob_path = os.path.join( + engine_cache_dir, os.listdir(engine_cache_dir)[0], "blob.bin" + ) + with open(blob_path, "rb") as f: + blob = f.read() + unpacked = pickle.loads(blob) + cached_stripped_engine = unpacked["serialized_engine"] + + assertions.assertTrue( + len(bytes(weight_included_engine)) > len(bytes(cached_stripped_engine)), + msg=f"cached engine size is not smaller than the weight included engine size. Weight included engine size: {len(bytes(weight_included_engine))}, cached stripped engine size: {len(bytes(cached_stripped_engine))}", + ) + + def test_dynamo_compile_with_refittable_weight_stripped_engine(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + exp_program = torch.export.export(pyt_model, args=example_inputs) + + engine_cache_dir = ( + "/tmp/test_dynamo_compile_with_refittable_weight_stripped_engine" + ) + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + def remove_timing_cache(path=TIMING_CACHE_PATH): + if os.path.exists(path): + os.remove(path) + + # The 1st iteration is to measure the compilation time without engine caching + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. 
+ inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + results = [] + times = [] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + for i in range(3): + remove_timing_cache() + torch._dynamo.reset() + if i == 0: + cache_built_engines = False + reuse_cached_engines = False + else: + cache_built_engines = True + reuse_cached_engines = True + + torch.cuda.synchronize() + start.record() + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + immutable_weights=False, + cache_built_engines=cache_built_engines, + reuse_cached_engines=reuse_cached_engines, + engine_cache_dir=engine_cache_dir, + strip_engine_weights=False, + refit_identical_engine_weights=False, + ) + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + results.append(trt_gm(*inputs)) + + assertions.assertNotEqual( + results[0].sum(), 0, msg="results[0] shouldn't be all zeros" + ) + assertions.assertNotEqual( + results[1].sum(), 0, msg="results[1] shouldn't be all zeros" + ) + assertions.assertNotEqual( + results[2].sum(), 0, msg="results[2] shouldn't be all zeros" + ) + + cos_sim = cosine_similarity(results[0], results[1]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(results[1], results[2]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + assertions.assertTrue( + times[0] > times[2], + msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", + ) + + def test_torch_compile_with_refittable_weight_stripped_engine(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + + engine_cache_dir = ( + "/tmp/test_torch_compile_with_refittable_weight_stripped_engine" + ) + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + def remove_timing_cache(path=TIMING_CACHE_PATH): + if os.path.exists(path): + os.remove(path) + + # The 1st iteration is to measure the compilation time without engine caching + # The 2nd and 3rd iterations are to measure the compilation time with engine caching. + # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration. + # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. 
+ inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + results = [] + times = [] + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + for i in range(3): + remove_timing_cache() + torch._dynamo.reset() + if i == 0: + cache_built_engines = False + reuse_cached_engines = False + else: + cache_built_engines = True + reuse_cached_engines = True + + torch.cuda.synchronize() + start.record() + compiled_model = torch.compile( + pyt_model, + backend="tensorrt", + options={ + "use_python_runtime": False, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "immutable_weights": False, + "cache_built_engines": cache_built_engines, + "reuse_cached_engines": reuse_cached_engines, + "engine_cache_dir": engine_cache_dir, + "strip_engine_weights": False, + "refit_identical_engine_weights": True, + }, + ) + results.append(compiled_model(*inputs)) # trigger the compilation + end.record() + torch.cuda.synchronize() + times.append(start.elapsed_time(end)) + + assertions.assertNotEqual( + results[0].sum(), 0, msg="results[0] shouldn't be all zeros" + ) + assertions.assertNotEqual( + results[1].sum(), 0, msg="results[1] shouldn't be all zeros" + ) + assertions.assertNotEqual( + results[2].sum(), 0, msg="results[2] shouldn't be all zeros" + ) + + cos_sim = cosine_similarity(results[0], results[1]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[0] doesn't match with results[1]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + cos_sim = cosine_similarity(results[1], results[2]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + assertions.assertTrue( + times[0] > times[2], + msg=f"Engine caching didn't speed up the compilation. 
Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms", + ) + + def test_different_args_dont_share_cached_engine(self): + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 4, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = self.conv(x) + out = self.relu(out) + return out + + pyt_model = MyModel().eval().to("cuda") + + engine_cache_dir = "/tmp/test_different_args_dont_share_cached_engine" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + inputs = [torch.rand((4, 3, 32, 32)).to("cuda")] + + for i in range(2): + if i == 0: + strip_engine_weights = False + else: + strip_engine_weights = True + + compiled_model = torch.compile( + pyt_model, + backend="tensorrt", + options={ + "use_python_runtime": True, + "enabled_precisions": {torch.float}, + "debug": False, + "min_block_size": 1, + "immutable_weights": False, + "cache_built_engines": True, + "reuse_cached_engines": True, + "engine_cache_dir": engine_cache_dir, + "strip_engine_weights": strip_engine_weights, + }, + ) + compiled_model(*inputs) + + assertions.assertEqual( + len(os.listdir(engine_cache_dir)), + 2, + msg=f"It has {len(os.listdir(engine_cache_dir))} cached engine(s) but should have 2 engines", + ) + + def test_constant_mul_in_refitting(self): + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.tensor(0.5, requires_grad=False) + + def forward(self, x): + out = x * self.weight + return out + + pyt_model = MyModel().eval().cuda() + inputs = [torch.randn((1, 3, 4, 4)).to("cuda")] + + exp_program = torch.export.export(pyt_model, args=tuple(inputs)) + + trt_module = torch_trt.compile( + pyt_model, + ir="dynamo", + inputs=tuple(inputs), + min_block_size=1, + immutable_weights=False, + use_python_runtime=True, + strip_engine_weights=True, + refit_identical_engine_weights=False, + ) + + refitted_trt_gm = refit_module_weights(trt_module, exp_program) + + outputs_pyt = pyt_model(*inputs) + outputs_trt = refitted_trt_gm(*inputs) + + cos_sim = cosine_similarity(outputs_pyt, outputs_trt) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + def test_two_TRTRuntime_in_refitting(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + batch = torch.export.Dim("batch", min=1, max=200) + exp_program = torch.export.export( + pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}} + ) + inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] + + pyt_results = pyt_model(*inputs) + + for i in range(2): + if i == 0: + use_python_runtime = True + else: + use_python_runtime = False + + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(inputs), + use_python_runtime=use_python_runtime, + debug=False, + min_block_size=1, + immutable_weights=False, + strip_engine_weights=True, + refit_identical_engine_weights=False, + ) + + output = trt_gm(*inputs) + assertions.assertEqual(output.sum(), 0, msg="results should be all zeros") + + refitted_trt_gm = refit_module_weights(trt_gm, exp_program) + refitted_output = refitted_trt_gm(*inputs) + cos_sim = cosine_similarity(pyt_results, refitted_output) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"{'PythonTorchTensorRTModule' if use_python_runtime else 'TorchTensorRTModule'} outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + @unittest.skip("Waiting for implementation") + def test_refit_identical_engine_weights(self): + pyt_model = models.resnet18(pretrained=True).eval().to("cuda") + example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),) + exp_program = torch.export.export(pyt_model, example_inputs) + + engine_cache_dir = "/tmp/test_refit_identical_engine_weights" + if os.path.exists(engine_cache_dir): + shutil.rmtree(engine_cache_dir) + + trt_gm = torch_trt.dynamo.compile( + exp_program, + tuple(example_inputs), + use_python_runtime=True, + enabled_precisions={torch.float}, + debug=False, + min_block_size=1, + immutable_weights=False, + strip_engine_weights=True, + refit_identical_engine_weights=True, + ) + output = trt_gm(*example_inputs) + + pyt_model2 = models.resnet18(pretrained=False).eval().to("cuda") + exp_program2 = torch.export.export(pyt_model2, example_inputs) + + try: + refit_module_weights(trt_gm, exp_program) + except Exception as e: + assertions.fail( + f"Refitting the engine with the same weights failed with the following error: {e}" + ) + + try: + refit_module_weights(trt_gm, exp_program2) + assertions.fail( + "Refitting the engine with different weights should have failed but it didn't" + ) + except Exception as e: + pass diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py index b52530efd1..f2bcaf7ede 100644 --- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -49,7 +49,7 @@ def test_resnet18(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, - "make_refittable": True, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -89,7 +89,7 @@ def test_save(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, - "make_refittable": True, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -123,7 +123,7 @@ def test_resnet18_modify_attribute(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, - "make_refittable": 
True, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -164,7 +164,7 @@ def test_resnet18_modify_attribute_no_refit(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, - "make_refittable": True, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -243,7 +243,7 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", - "make_refittable": True, + "immutable_weights": False, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -304,7 +304,7 @@ def set_weights(self): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", - "make_refittable": True, + "immutable_weights": False, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -367,7 +367,7 @@ def set_layer(self): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", - "make_refittable": True, + "immutable_weights": False, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -436,7 +436,7 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", - "make_refittable": True, + "immutable_weights": False, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec)