|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -from contextlib import nullcontext, contextmanager |
| 15 | +from contextlib import contextmanager, nullcontext |
16 | 16 | from typing import Dict, List, Optional, Set, Tuple |
17 | 17 |
|
18 | 18 | import torch |
@@ -102,9 +102,7 @@ def onload_(self): |
102 | 102 | with self._pinned_memory_tensors() as pinned_memory: |
103 | 103 | for module in self.modules: |
104 | 104 | for param in module.parameters(): |
105 | | - param.data = pinned_memory[param].to( |
106 | | - self.onload_device, non_blocking=self.non_blocking |
107 | | - ) |
| 105 | + param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking) |
108 | 106 | else: |
109 | 107 | for group_module in self.modules: |
110 | 108 | for param in group_module.parameters(): |
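
The hunk above is the pinned-memory onload path: each parameter's CPU copy lives in page-locked memory so the host-to-device copy can run with `non_blocking=True`. A minimal, self-contained sketch of that pattern (illustrative names only, not the diffusers implementation) follows:

```python
import torch

def onload_params(params, onload_device: torch.device) -> None:
    # Stage each CPU copy in page-locked (pinned) memory; non_blocking copies
    # only overlap with compute when the source tensor is pinned.
    for param in params:
        pinned = param.data.cpu().pin_memory()
        param.data = pinned.to(onload_device, non_blocking=True)

if torch.cuda.is_available():
    layer = torch.nn.Linear(8, 8)
    onload_params(layer.parameters(), torch.device("cuda"))
```
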
@@ -392,7 +390,9 @@ def apply_group_offloading( |
392 | 390 | module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage |
393 | 391 | ) |
394 | 392 | elif offload_type == "leaf_level": |
395 | | - _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage) |
| 393 | + _apply_group_offloading_leaf_level( |
| 394 | + module, offload_device, onload_device, non_blocking, stream, low_cpu_mem_usage |
| 395 | + ) |
396 | 396 | else: |
397 | 397 | raise ValueError(f"Unsupported offload_type: {offload_type}") |
398 | 398 |
|
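
For context, the dispatch reformatted above sits behind the public `apply_group_offloading` entry point. A hedged usage sketch is below; the keyword names (`use_stream`, `num_blocks_per_group`, and the `low_cpu_mem_usage` flag threaded through in this diff) and the `diffusers.hooks` import path are assumptions based on this file, so check them against the installed release:

```python
import torch
from diffusers.hooks import apply_group_offloading  # import path assumed

class TinyModel(torch.nn.Module):
    # Group offloading walks ModuleList / Sequential blocks, so the toy model uses one.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = TinyModel()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,          # prefetch on a CUDA stream (assumed keyword)
    low_cpu_mem_usage=True,   # keep CPU copies pageable (assumed keyword)
)
```
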
@@ -425,11 +425,6 @@ def _apply_group_offloading_block_level( |
425 | 425 | for overlapping computation and data transfer. |
426 | 426 | """ |
427 | 427 |
|
428 | | - # Create a pinned CPU parameter dict for async data transfer if streams are to be used |
429 | | - cpu_param_dict = None |
430 | | - if stream is not None: |
431 | | - cpu_param_dict = _get_pinned_cpu_param_dict(module) |
432 | | - |
433 | 428 | # Create module groups for ModuleList and Sequential blocks |
434 | 429 | modules_with_group_offloading = set() |
435 | 430 | unmatched_modules = [] |
@@ -522,11 +517,6 @@ def _apply_group_offloading_leaf_level( |
522 | 517 | for overlapping computation and data transfer. |
523 | 518 | """ |
524 | 519 |
|
525 | | - # Create a pinned CPU parameter dict for async data transfer if streams are to be used |
526 | | - cpu_param_dict = None |
527 | | - if stream is not None: |
528 | | - cpu_param_dict = _get_pinned_cpu_param_dict(module) |
529 | | - |
530 | 520 | # Create module groups for leaf modules and apply group offloading hooks |
531 | 521 | modules_with_group_offloading = set() |
532 | 522 | for name, submodule in module.named_modules(): |
@@ -641,19 +631,15 @@ def _apply_lazy_group_offloading_hook( |
641 | 631 | registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING) |
642 | 632 |
|
643 | 633 |
|
644 | | -def _get_cpu_param_dict(module: torch.nn.Module, low_cpu_mem_usage: bool = False) -> Dict[torch.nn.Parameter, torch.Tensor]: |
| 634 | +def _get_cpu_param_dict( |
| 635 | + module: torch.nn.Module, low_cpu_mem_usage: bool = False |
| 636 | +) -> Dict[torch.nn.Parameter, torch.Tensor]: |
645 | 637 | cpu_param_dict = {} |
646 | 638 | for param in module.parameters(): |
647 | | - if low_cpu_mem_usage: |
648 | | - cpu_param_dict[param] = param.data.cpu() |
649 | | - else: |
650 | | - cpu_param_dict[param] = param.data.cpu().pin_memory() |
| 639 | + cpu_param_dict[param] = param.data.cpu() if low_cpu_mem_usage else param.data.cpu().pin_memory() |
651 | 640 |
|
652 | 641 | for buffer in module.buffers(): |
653 | | - if low_cpu_mem_usage: |
654 | | - cpu_param_dict[buffer] = buffer.data.cpu() |
655 | | - else: |
656 | | - cpu_param_dict[buffer] = buffer.data.cpu().pin_memory() |
| 642 | + cpu_param_dict[buffer] = buffer.data.cpu() if low_cpu_mem_usage else buffer.data.cpu().pin_memory() |
657 | 643 |
|
658 | 644 | return cpu_param_dict |
659 | 645 |
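
The refactor above centralizes the CPU-copy bookkeeping in `_get_cpu_param_dict`: by default each parameter and buffer gets a pinned CPU copy up front, while `low_cpu_mem_usage=True` keeps the copies pageable (the first hunk shows the onload path then pinning lazily through a context manager). A standalone sketch of the same trade-off, with an illustrative helper name, looks like this:

```python
from typing import Dict

import torch

def build_cpu_copies(module: torch.nn.Module, low_cpu_mem_usage: bool = False) -> Dict[torch.Tensor, torch.Tensor]:
    # Illustrative re-implementation of the idea, not the library code.
    # low_cpu_mem_usage=False: pin once, paying extra page-locked host RAM for
    # fast, overlappable transfers. low_cpu_mem_usage=True: keep pageable copies.
    cpu_copies = {}
    for tensor in list(module.parameters()) + list(module.buffers()):
        cpu = tensor.data.cpu()
        cpu_copies[tensor] = cpu if low_cpu_mem_usage else cpu.pin_memory()
    return cpu_copies

if torch.cuda.is_available():  # pin_memory() needs an accelerator backend
    layer = torch.nn.Linear(32, 32)
    assert all(t.is_pinned() for t in build_cpu_copies(layer).values())
    assert not any(t.is_pinned() for t in build_cpu_copies(layer, low_cpu_mem_usage=True).values())
```
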
|
|