1 change: 1 addition & 0 deletions .github/workflows/build-test-linux.yml
@@ -174,6 +174,7 @@ jobs:
           python -m pip install -r requirements.txt
           cd dynamo
           python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/
+          python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml test_modelopt_models.py
           popd
 
   tests-py-dynamo-serde:
122 changes: 62 additions & 60 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -9,6 +9,7 @@
 import tensorrt as trt
 import torch
 from torch.export import ExportedProgram
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import partitioning
@@ -144,71 +145,72 @@ def _refit_single_trt_engine_with_gm(
     Refit a TensorRT Engine in place
     """
 
-    refitted = set()
-    torch_device = get_model_device(new_gm)
-    refitter = trt.Refitter(old_engine, TRT_LOGGER)
-    weight_list = refitter.get_all_weights()
-
-    if weight_name_map:
-        # Get the refitting mapping
-        trt_wt_location = (
-            trt.TensorLocation.DEVICE
-            if torch_device.type == "cuda"
-            else trt.TensorLocation.HOST
-        )
-
-        constant_mapping: dict[str, Any] = weight_name_map.pop(
-            "constant_mapping", {}
-        )  # type: ignore
-        mapping = construct_refit_mapping_from_weight_name_map(
-            weight_name_map, new_gm.state_dict()
-        )
-        constant_mapping_with_type = {}
-
-        for constant_name, val in constant_mapping.items():
-            np_weight_type = val.dtype
-            val_tensor = torch.from_numpy(val).cuda()
-            trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
-            torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-            constant_mapping_with_type[constant_name] = (
-                val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
-                trt_dtype,
-            )
-
-        mapping.update(constant_mapping_with_type)
-
-        for layer_name in weight_list:
-            if layer_name not in mapping:
-                logger.warning(f"{layer_name} is not found in weight mapping.")
-                continue
-            # Use Numpy to create weights
-            weight, weight_dtype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(
-                weight_dtype, weight.data_ptr(), torch.numel(weight)
-            )
-            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-        assert (
-            len(refitter.get_missing_weights()) == 0
-        ), "Fast refitting failed due to incomplete mapping"
-
-    else:
-        mapping = construct_refit_mapping(new_gm, input_list, settings)
-        trt_wt_location = trt.TensorLocation.HOST
-        for layer_name in weight_list:
-            if layer_name not in mapping:
-                raise AssertionError(f"{layer_name} is not found in weight mapping")
-            # Use Numpy to create weights
-            weight, datatype = mapping[layer_name]
-            trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
-            refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
-            refitted.add(layer_name)
-
-        if len(refitted) != len(weight_list):
-            logger.warning("Not all weights have been refitted!!!")
-
-        if not refitter.refit_cuda_engine():
-            logger.error("Error: failed to refit new weights.")
-            raise AssertionError("Refitting failed.")
+    with unset_fake_temporarily():
+        refitted = set()
+        torch_device = get_model_device(new_gm)
+        refitter = trt.Refitter(old_engine, TRT_LOGGER)
+        weight_list = refitter.get_all_weights()
+
+        if weight_name_map:
+            # Get the refitting mapping
+            trt_wt_location = (
+                trt.TensorLocation.DEVICE
+                if torch_device.type == "cuda"
+                else trt.TensorLocation.HOST
+            )
+
+            constant_mapping: dict[str, Any] = weight_name_map.pop(
+                "constant_mapping", {}
+            )  # type: ignore
+            mapping = construct_refit_mapping_from_weight_name_map(
+                weight_name_map, new_gm.state_dict()
+            )
+            constant_mapping_with_type = {}
+
+            for constant_name, val in constant_mapping.items():
+                np_weight_type = val.dtype
+                val_tensor = torch.from_numpy(val).cuda()
+                trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
+                torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
+                constant_mapping_with_type[constant_name] = (
+                    val_tensor.clone().reshape(-1).contiguous().to(torch_dtype),
+                    trt_dtype,
+                )
+
+            mapping.update(constant_mapping_with_type)
+
+            for layer_name in weight_list:
+                if layer_name not in mapping:
+                    logger.warning(f"{layer_name} is not found in weight mapping.")
+                    continue
+                # Use Numpy to create weights
+                weight, weight_dtype = mapping[layer_name]
+                trt_wt_tensor = trt.Weights(
+                    weight_dtype, weight.data_ptr(), torch.numel(weight)
+                )
+                refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+            assert (
+                len(refitter.get_missing_weights()) == 0
+            ), "Fast refitting failed due to incomplete mapping"
+
+        else:
+            mapping = construct_refit_mapping(new_gm, input_list, settings)
+            trt_wt_location = trt.TensorLocation.HOST
+            for layer_name in weight_list:
+                if layer_name not in mapping:
+                    raise AssertionError(f"{layer_name} is not found in weight mapping")
+                # Use Numpy to create weights
+                weight, datatype = mapping[layer_name]
+                trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
+                refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
+                refitted.add(layer_name)
+
+            if len(refitted) != len(weight_list):
+                logger.warning("Not all weights have been refitted!!!")
+
+            if not refitter.refit_cuda_engine():
+                logger.error("Error: failed to refit new weights.")
+                raise AssertionError("Refitting failed.")
 
 
 def refit_module_weights(
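Note on the change above: torch.export tracing can leave a FakeTensorMode active, in which case any tensor the refit path materializes carries shape and dtype metadata but no real storage, so the TensorRT refitter would read garbage. Wrapping the refit body in unset_fake_temporarily() guarantees concrete weight bytes. A minimal standalone sketch of the behavior (illustrative, not part of this PR):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch.fx.experimental.proxy_tensor import unset_fake_temporarily

    with FakeTensorMode():
        fake = torch.ones(2, 2)      # FakeTensor: metadata only, no real storage
        with unset_fake_temporarily():
            real = torch.ones(2, 2)  # real tensor; safe to hand to the TRT refitter
    print(type(fake).__name__, type(real).__name__)  # FakeTensor Tensor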
40 changes: 20 additions & 20 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -21,6 +21,7 @@
 import tensorrt as trt
 import torch
 import torch.fx
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch.utils._python_dispatch import _disable_current_modes
@@ -41,6 +42,7 @@
     get_node_io,
     get_node_name,
     get_trt_tensor,
+    to_torch,
 )
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, get_model_device, to_torch_device
 from torch_tensorrt.fx.observer import Observer
@@ -408,27 +410,29 @@ def find_weight(
             np_map: the map from weight name to np values in INetworkDefinition
             state_dict: state of the graph module
         """
-        network_weight = torch.from_numpy(np_map[weight_name]).to(device)
-        for sd_w_name, sd_weight in state_dict.items():
-            if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
-                del state_dict[sd_w_name]
-                return sd_w_name
-        return ""
+        with unset_fake_temporarily():
+            network_weight = torch.from_numpy(np_map[weight_name]).to(device)
+            for sd_w_name, sd_weight in state_dict.items():
+                if TRTInterpreter.check_weight_equal(sd_weight, network_weight, device):
+                    del state_dict[sd_w_name]
+                    return sd_w_name
+            return ""
 
     @staticmethod
     def check_weight_equal(
         sd_weight: torch.tensor,
         network_weight: Union[torch.Tensor, np.ndarray],
         device: torch.device,
     ) -> Any:
-        if not isinstance(network_weight, torch.Tensor):
-            network_weight = torch.from_numpy(network_weight).to(device)
-        try:
-            return sd_weight.shape == network_weight.shape and torch.all(
-                torch.abs(sd_weight - network_weight) < 0.01
-            )
-        except Exception:
-            return torch.all(sd_weight == network_weight)
+        with unset_fake_temporarily():
+            if not isinstance(network_weight, torch.Tensor):
+                network_weight = torch.from_numpy(network_weight).to(device)
+            try:
+                return sd_weight.shape == network_weight.shape and torch.all(
+                    torch.abs(sd_weight - network_weight) < 0.01
+                )
+            except Exception:
+                return torch.all(sd_weight == network_weight)
 
     def _save_weight_mapping(self) -> None:
         """
@@ -887,19 +891,15 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         return converter(self.ctx, target, args, kwargs, self._cur_node_name)
 
     def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
-        with _disable_current_modes():
-            from torch_tensorrt.dynamo.conversion.converter_utils import to_numpy
-
+        with _disable_current_modes(), unset_fake_temporarily():
             frozen_attr = self.fetch_attr(target)
 
             if isinstance(frozen_attr, torch.nn.Parameter):
                 constant_tensor = frozen_attr.data
             else:
                 constant_tensor = frozen_attr
 
-            network_constant = to_numpy(constant_tensor)
-
-            return network_constant
+            return to_torch(constant_tensor)
 
     def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         assert isinstance(target, str)
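The get_attr change swaps to_numpy for to_torch so fetched constants keep their original dtype: numpy has no bfloat16, so the old numpy round-trip silently upcast bf16 parameters to float32. A hedged sketch of the failure mode the change avoids (names here are illustrative, not from the PR):

    import torch

    param = torch.nn.Parameter(torch.randn(4, dtype=torch.bfloat16))
    as_torch = param.data.cpu().contiguous()   # the to_torch() path: bf16 is preserved
    print(as_torch.dtype)                      # torch.bfloat16
    try:
        param.data.numpy()                     # the old to_numpy() path
    except TypeError as err:
        print(err)                             # numpy does not support bfloat16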
103 changes: 92 additions & 11 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -9,6 +9,7 @@
 import tensorrt as trt
 import torch
 import torch_tensorrt.dynamo.conversion.impl as impl
+from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.node import Argument, Target
 from torch.fx.passes.shape_prop import TensorMetadata
 from torch_tensorrt import _enums
@@ -340,17 +341,47 @@ def create_constant(
     Returns:
         A TensorRT ITensor that represents the given value.
     """
-    shape = (1,)
-    # Rank 0 constant is required in IFillLayer inputs.
-    if min_rank == 0:
-        shape = trt.Dims()
-    numpy_value = to_numpy(value, dtype)
-    constant = ctx.net.add_constant(
-        shape if isinstance(value, (int, float, bool)) else value.shape,
-        numpy_value.copy() if isinstance(numpy_value, np.ndarray) else numpy_value,
-    )
-    constant.name = name
-    return constant.get_output(0)
+    with unset_fake_temporarily():
+        torch_value = to_torch(value, dtype)
+        if torch_value.dtype == torch.float64:
+            raise ValueError(
+                "TensorRT does not support float64 (double) precision. To resolve this, please set truncate_double=True in your compilation settings and re-run the model."
+            )
+        # Rank 0 constant is required in IFillLayer inputs.
+        if min_rank == 0 and isinstance(value, (int, float, bool)):
+            shape = trt.Dims()
+        elif list(torch_value.shape) == []:
+            shape = trt.Dims()
+        else:
+            shape = list(torch_value.shape)
+
+        if torch_value is not None:
+            if torch_value.dtype == torch.bfloat16:
+                torch_value_fp32 = torch_value.to(torch.float32)
+                numpy_value = torch_value_fp32.numpy()
+            else:
+                numpy_value = torch_value.numpy()
+
+            constant = ctx.net.add_constant(
+                shape,
+                numpy_value,
+            )
+            constant.name = name
+
+            if torch_value.dtype == torch.bfloat16:
+                return cast_trt_tensor(
+                    ctx,
+                    constant.get_output(0),
+                    trt.DataType.BF16,
+                    name + "_bf16_cast",
+                )
+
+            return constant.get_output(0)
+        else:
+            raise ValueError(
+                f"Cannot convert tensor '{name}' to a TensorRT constant because its value is None."
+            )
 
 
 def get_trt_tensor(
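create_constant now fails fast on float64 with a pointer to truncate_double, and stages bfloat16 weights through an fp32 numpy buffer before casting the TRT constant back to BF16 in-graph. A hedged usage sketch of the float64 path, with a made-up module and shapes (torch_tensorrt.compile arguments shown are assumptions based on the error message above):

    import torch
    import torch_tensorrt

    class Doubled(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # float64 constant: unsupported by TensorRT without truncation
            self.scale = torch.nn.Parameter(torch.randn(3, dtype=torch.float64))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self.scale

    # Without truncate_double=True, create_constant raises the new ValueError;
    # with it, fp64 constants are downcast to fp32 at compile time.
    trt_mod = torch_tensorrt.compile(
        Doubled().cuda().eval(),
        ir="dynamo",
        inputs=[torch.randn(1, 3).cuda()],
        truncate_double=True,
    )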
@@ -564,6 +595,9 @@ def to_numpy(
         value = value.dequantize()
     elif value.dtype == torch.bfloat16:
         # TODO: Remove when numpy has a BF16 type
+        _LOGGER.warning(
+            "Requested a conversion of bfloat16 tensor from torch to numpy which isn't supported. Casting this tensor to FP32 precision currently. Please use to_torch() API for better data representation",
+        )
         value = value.to(torch.float)
 
     output = value.cpu().detach().contiguous().numpy()
@@ -589,6 +623,53 @@
     )
 
 
+def to_torch(
+    value: Optional[Union[torch.Tensor, np.ndarray, int, float, bool]],
+    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType, _enums.dtype]] = None,
+) -> Optional[torch.Tensor]:
+    """
+    Convert a Numpy array, or scalar to a PyTorch tensor and move it to CPU
+    Args:
+        value (Optional[Union[torch.Tensor, np.ndarray, int, float, bool]]):
+            A PyTorch tensor, Numpy array, int, float, or bool
+        dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
+            If a dtype is given, we will convert the type of the given `value` to this dtype.
+    Returns:
+        A PyTorch tensor or None, if the input was None.
+    """
+
+    cpu_device = torch.device("cpu")
+    torch_dtype = (
+        _enums.dtype._from(dtype).to(torch.dtype, use_default=True) if dtype else None
+    )
+
+    with unset_fake_temporarily():
+        if value is None:
+            return None
+
+        elif isinstance(value, torch.Tensor):
+            output = value.to(cpu_device).contiguous()
+
+        elif isinstance(value, np.ndarray):
+            output = torch.from_numpy(value).to(cpu_device).contiguous()
+
+        elif isinstance(value, int):
+            output = torch.tensor([value], device=cpu_device, dtype=torch.int32)
+
+        elif isinstance(value, float):
+            output = torch.tensor([value], device=cpu_device, dtype=torch.float32)
+
+        elif isinstance(value, bool):
+            output = torch.tensor([value], device=cpu_device, dtype=torch.bool)
+
+        else:
+            raise AssertionError(
+                f"to_torch can only be called on None, bool, int, float, np.ndarray, or torch.Tensor, got an object of type: {type(value)}"
+            )
+
+        return output.to(torch_dtype) if torch_dtype else output
+
+
 def flatten_dims(
     input: Sequence[Union[TRTTensor, torch.Tensor, np.ndarray]],
     start_dim: int,
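For reference, a hedged sketch of how the new to_torch helper behaves across input types; the commented outputs are expectations read off the code above, not captured output:

    import numpy as np
    import torch
    from torch_tensorrt.dynamo.conversion.converter_utils import to_torch

    print(to_torch(None))                       # None passes through
    print(to_torch(3).dtype)                    # torch.int32 (scalars become 1-element tensors)
    print(to_torch(3.0).dtype)                  # torch.float32
    print(to_torch(np.ones((2, 2))).shape)      # torch.Size([2, 2])
    print(to_torch(torch.ones(2), dtype=torch.float16).dtype)  # torch.float16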