30 changes: 28 additions & 2 deletions py/torch_tensorrt/_features.py
@@ -1,3 +1,4 @@
+import importlib
 import os
 import sys
 from collections import namedtuple
@@ -15,6 +16,7 @@
         "dynamo_frontend",
         "fx_frontend",
         "refit",
+        "qdp_plugin",
     ],
 )
 
@@ -39,14 +41,24 @@
 _FX_FE_AVAIL = True
 _REFIT_AVAIL = True
 
+if importlib.util.find_spec("tensorrt.plugin"):
+    _QDP_PLUGIN_AVAIL = True
+else:
+    _QDP_PLUGIN_AVAIL = False
+
 ENABLED_FEATURES = FeatureSet(
-    _TS_FE_AVAIL, _TORCHTRT_RT_AVAIL, _DYNAMO_FE_AVAIL, _FX_FE_AVAIL, _REFIT_AVAIL
+    _TS_FE_AVAIL,
+    _TORCHTRT_RT_AVAIL,
+    _DYNAMO_FE_AVAIL,
+    _FX_FE_AVAIL,
+    _REFIT_AVAIL,
+    _QDP_PLUGIN_AVAIL,
 )
 
 
 def _enabled_features_str() -> str:
     enabled = lambda x: "ENABLED" if x else "DISABLED"
-    out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n" # type: ignore[no-untyped-call]
+    out_str: str = f"Enabled Features:\n - Dynamo Frontend: {enabled(_DYNAMO_FE_AVAIL)}\n - Torch-TensorRT Runtime: {enabled(_TORCHTRT_RT_AVAIL)}\n - FX Frontend: {enabled(_FX_FE_AVAIL)}\n - TorchScript Frontend: {enabled(_TS_FE_AVAIL)}\n - Refit: {enabled(_REFIT_AVAIL)}\n - QDP Plugin: {enabled(_QDP_PLUGIN_AVAIL)}\n" # type: ignore[no-untyped-call]
     return out_str
 
 
@@ -64,6 +76,20 @@ def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
     return wrapper
 
 
+def needs_qdp_plugin(f: Callable[..., Any]) -> Callable[..., Any]:
+    def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
+        if ENABLED_FEATURES.qdp_plugin:
+            return f(*args, **kwargs)
+        else:
+
+            def not_implemented(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
+                raise NotImplementedError("QDP Plugin is not available")
+
+            return not_implemented(*args, **kwargs)
+
+    return wrapper
+
+
 def needs_refit(f: Callable[..., Any]) -> Callable[..., Any]:
     def wrapper(*args: List[Any], **kwargs: Dict[str, Any]) -> Any:
         if ENABLED_FEATURES.refit:
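A minimal sketch of how the new gate behaves (the `build_rmsnorm_plugin` function below is hypothetical, purely for illustration): `needs_qdp_plugin` forwards the call unchanged when `tensorrt.plugin` was importable at module load, and raises `NotImplementedError` otherwise.

```python
from torch_tensorrt._features import ENABLED_FEATURES, needs_qdp_plugin

@needs_qdp_plugin
def build_rmsnorm_plugin() -> str:  # hypothetical function for illustration
    return "plugin built"

if ENABLED_FEATURES.qdp_plugin:
    print(build_rmsnorm_plugin())  # wrapped function runs as-is
else:
    # without tensorrt.plugin installed, the wrapper raises instead
    build_rmsnorm_plugin()  # NotImplementedError: QDP Plugin is not available
```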
10 changes: 9 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py
@@ -3,12 +3,12 @@
 from types import FunctionType
 from typing import Any, Callable, Tuple
 
-import tensorrt.plugin as trtp
 import torch
 from sympy import lambdify
 from torch._dynamo.source import LocalSource
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
+from torch_tensorrt._features import needs_qdp_plugin
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -28,6 +28,13 @@ def mksym(
 
 
 def _generate_plugin(plugin_name: str) -> None:
+    try:
+        import tensorrt.plugin as trtp
+    except ImportError as e:
+        raise RuntimeError(
+            "Unable to import TensorRT plugin. Please install the TensorRT plugin library (https://github.com/NVIDIA/TensorRT-plugin-library?tab=readme-ov-file#installation) to add support for compiling quantized models"
+        ) from e
+
     namespace, name = plugin_name.split("::")
 
     # retrieve the corresponding torch operation using the passed in string
@@ -211,6 +218,7 @@ def _generic_plugin_impl(
     trtp.impl(plugin_name)(plugin_impl)
 
 
+@needs_qdp_plugin
 def generate_plugin(plugin_name: str) -> None:
     """
     Generate the Plugin using external kernels and TensorRT Quick Deployable Plugin APIs.
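As a usage sketch (the op name is hypothetical and assumes a custom op was already registered via `torch.library`): `generate_plugin` takes a `"<namespace>::<name>"` string, splits it on `"::"`, and builds the Quick Deployable Plugin definition for the matching `torch.ops` entry.

```python
# Hypothetical example: assumes torch.ops.my_ops.rmsnorm was previously
# registered with torch.library; generate_plugin then derives the TensorRT
# QDP definition from that op's signature.
from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin

generate_plugin("my_ops::rmsnorm")  # "<namespace>::<name>", split on "::" internally
```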
py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin_converter.py
@@ -4,12 +4,9 @@
 
 import numpy as np
 import tensorrt as trt
-
-# Seems like a bug in TensorRT
-import tensorrt.plugin as trtp
 import torch
-from tensorrt.plugin._lib import QDP_REGISTRY
 from torch.fx.node import Argument, Node, Target
+from torch_tensorrt._features import needs_qdp_plugin
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
@@ -32,6 +29,15 @@ def _generate_plugin_converter(
     supports_dynamic_shapes: bool = False,
     requires_output_allocator: bool = False,
 ) -> DynamoConverterImplSignature:
+    try:
+        import tensorrt.plugin as trtp
+    except ImportError as e:
+        raise RuntimeError(
+            "Unable to import TensorRT plugin. Please install the TensorRT plugin library (https://github.com/NVIDIA/TensorRT-plugin-library?tab=readme-ov-file#installation) to add support for compiling quantized models"
+        ) from e
+    from tensorrt.plugin._lib import QDP_REGISTRY
+
     torch_target = getattr(getattr(torch.ops, namespace), op_name)
     overload_str = overload if overload else ""
     overload_name = overload_str if overload else "default"
@@ -101,6 +107,7 @@ def custom_kernel_converter(
     return custom_kernel_converter
 
 
+@needs_qdp_plugin
 def generate_plugin_converter(
     plugin_id: str,
     capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
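And a companion sketch for the converter side (same hypothetical op; passing `supports_dynamic_shapes` assumes `generate_plugin_converter` forwards it to the `_generate_plugin_converter` signature shown above):

```python
# Hypothetical example: registers a dynamo converter that dispatches
# torch.ops.my_ops.rmsnorm to the QDP plugin generated earlier.
from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import (
    generate_plugin_converter,
)

generate_plugin_converter(
    "my_ops::rmsnorm",
    supports_dynamic_shapes=True,  # assumed to be forwarded to _generate_plugin_converter
)
```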