|
37 | 37 | _LAYER_EXECUTION_TRACKER = "layer_execution_tracker" |
38 | 38 | _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading" |
39 | 39 |
|
| 40 | +# Always use memory-efficient CPU offloading to minimize RAM usage |
| 41 | + |
40 | 42 | _SUPPORTED_PYTORCH_LAYERS = ( |
41 | 43 |     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, |
42 | 44 |     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d, |
@@ -80,30 +82,61 @@ def onload_(self): |
80 | 82 |             self.stream.synchronize() |
81 | 83 |
|
82 | 84 |         with context: |
| 85 | +            # Only transfer parameters that aren't already on the target device |
83 | 86 |             for group_module in self.modules: |
84 | | -                group_module.to(self.onload_device, non_blocking=self.non_blocking) |
| 87 | +                if any(p.device != self.onload_device for p in group_module.parameters()): |
| 88 | +                    group_module.to(self.onload_device, non_blocking=self.non_blocking) |
| 89 | + |
85 | 90 |             if self.parameters is not None: |
86 | 91 |                 for param in self.parameters: |
87 | | -                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) |
| 92 | +                    if param.device != self.onload_device: |
| 93 | +                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking) |
| 94 | + |
88 | 95 |             if self.buffers is not None: |
89 | 96 |                 for buffer in self.buffers: |
90 | | -                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) |
| 97 | +                    if buffer.device != self.onload_device: |
| 98 | +                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking) |
91 | 99 |
|
92 | 100 |     def offload_(self): |
93 | 101 |         r"""Offloads the group of modules to the offload_device.""" |
94 | 102 |         # Synchronize if using stream |
95 | 103 |         if self.stream is not None: |
96 | 104 |             torch.cuda.current_stream().synchronize() |
97 | 105 |
|
98 | | -        # Use regular to() method for all cases - much simpler! |
99 | | -        for group_module in self.modules: |
100 | | -            group_module.to(self.offload_device, non_blocking=self.non_blocking) |
101 | | -        if self.parameters is not None: |
102 | | -            for param in self.parameters: |
103 | | -                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking) |
104 | | -        if self.buffers is not None: |
105 | | -            for buffer in self.buffers: |
106 | | -                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking) |
| 106 | +        # For CPU offloading, use a method that preserves memory mapping benefits |
| 107 | +        if self.offload_device.type == 'cpu': |
| 108 | +            # Empty GPU cache before offloading to reduce memory fragmentation |
| 109 | +            if torch.cuda.is_available(): |
| 110 | +                torch.cuda.empty_cache() |
| 111 | + |
| 112 | +            # Use to() method directly on modules |
| 113 | +            for group_module in self.modules: |
| 114 | +                # Don't make copies if already on CPU |
| 115 | +                if any(p.device.type != 'cpu' for p in group_module.parameters()): |
| 116 | +                    group_module.to(self.offload_device, non_blocking=self.non_blocking) |
| 117 | + |
| 118 | +            # Handle explicit parameters - avoid copies when already on CPU |
| 119 | +            if self.parameters is not None: |
| 120 | +                for param in self.parameters: |
| 121 | +                    if param.device.type != 'cpu': |
| 122 | +                        # Let PyTorch handle the transfer which can preserve memory mapping |
| 123 | +                        param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking) |
| 124 | + |
| 125 | +            # Handle buffers - avoid copies when already on CPU |
| 126 | +            if self.buffers is not None: |
| 127 | +                for buffer in self.buffers: |
| 128 | +                    if buffer.device.type != 'cpu': |
| 129 | +                        buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking) |
| 130 | +        else: |
| 131 | +            # For non-CPU offloading, use the regular approach |
| 132 | +            for group_module in self.modules: |
| 133 | +                group_module.to(self.offload_device, non_blocking=self.non_blocking) |
| 134 | +            if self.parameters is not None: |
| 135 | +                for param in self.parameters: |
| 136 | +                    param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking) |
| 137 | +            if self.buffers is not None: |
| 138 | +                for buffer in self.buffers: |
| 139 | +                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking) |
107 | 140 |
|
108 | 141 |         # After offloading, we can unpin the memory if configured to do so |
109 | 142 |         # We'll keep it pinned by default for better performance |
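
The pattern introduced in this hunk (skip the transfer when a tensor already sits on the target device, and empty the CUDA cache before a GPU-to-CPU offload) can be exercised outside the hook machinery. The following is a minimal sketch, not code from this file; the `offload_group` helper and its arguments are hypothetical stand-ins:

```python
import torch
import torch.nn as nn


def offload_group(modules, offload_device=torch.device("cpu"), non_blocking=False):
    """Move a group of modules to `offload_device`, mirroring the conditional
    logic in this diff: clear the CUDA cache before a host offload and skip
    modules whose parameters already live on the target device."""
    if offload_device.type == "cpu" and torch.cuda.is_available():
        # Reduce fragmentation before the device-to-host copies begin
        torch.cuda.empty_cache()
    for module in modules:
        # Only pay for a transfer when something actually lives elsewhere
        if any(p.device.type != offload_device.type for p in module.parameters()):
            module.to(offload_device, non_blocking=non_blocking)


# Example usage: one layer already on CPU, one (optionally) on GPU
layers = [nn.Linear(8, 8), nn.Linear(8, 8)]
if torch.cuda.is_available():
    layers[1].cuda()
offload_group(layers)
assert all(p.device.type == "cpu" for layer in layers for p in layer.parameters())
```

Comparing `device.type` rather than full `device` objects also sidesteps the fact that `torch.device("cuda")` and `torch.device("cuda:0")` compare unequal, which may be why the offload path above checks `.type`.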
@@ -314,7 +347,8 @@ def apply_group_offloading( |
314 | 347 |             If True, offloading and onloading is done with non-blocking data transfer. |
315 | 348 |         use_stream (`bool`, defaults to `False`): |
316 | 349 |             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for |
317 | | -            overlapping computation and data transfer. |
| 350 | +            overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used |
| 351 | +            to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies. |
318 | 352 |
|
319 | 353 |     Example: |
320 | 354 |         ```python |
|