Skip to content

Commit 2540824

Browse files
committed
Delete one copy
1 parent c286767 commit 2540824

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,8 @@ def compile(
693693

694694
# Move the weights in the state_dict to CPU
695695
if offload_module_to_cpu:
696-
deallocate_module(exported_program.module(), delete_module=False)
696+
deallocate_module(gm, delete_module=False)
697+
# deallocate_module(exported_program.module(), delete_module=False)
697698
logger.info(
698699
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
699700
)

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -512,8 +512,7 @@ def _save_weight_mapping(self) -> None:
512512
_LOGGER.info("Building weight name mapping...")
513513
# Stage 1: Name mapping
514514
torch_device = to_torch_device(self.compilation_settings.device)
515-
self.module.to(torch_device)
516-
sd = self.module.state_dict()
515+
sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
517516
weight_name_map: dict[str, Any] = {}
518517
weight_refit_map = self.ctx.weight_refit_map
519518
constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ def constant_fold(
3737
# For TRT INetwork construction the constants are moved to CPU in get_attr call.
3838
for node, constant in cf.node_replacements.items():
3939
replace_node_with_constant(
40-
gm, node, torch.nn.Parameter(constant, requires_grad=False)
40+
gm,
41+
node,
42+
torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
4143
)
4244

4345
erased_params = []

0 commit comments

Comments
 (0)