|
1 |
| -import gc |
2 | 1 | import os
|
3 | 2 | import random
|
4 | 3 | from pathlib import Path
|
@@ -97,13 +96,22 @@ def save_unsharded_model(
|
97 | 96 | else:
|
98 | 97 | save_state_dict(state_dict, checkpoint, use_safetensors)
|
99 | 98 |
|
100 |
| - def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): |
| 99 | + def load_unsharded_model( |
| 100 | + self, |
| 101 | + model: GeminiDDP, |
| 102 | + checkpoint: str, |
| 103 | + strict: bool = True, |
| 104 | + low_cpu_mem_mode: bool = True, |
| 105 | + num_threads: int = 1, |
| 106 | + ): |
101 | 107 | """
|
102 | 108 | Load model from checkpoint with automatic unwrapping.
|
103 | 109 | The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
|
104 | 110 | """
|
105 | 111 | assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
|
106 |
| - super().load_unsharded_model(model, checkpoint, strict=strict) |
| 112 | + super().load_unsharded_model( |
| 113 | + model, checkpoint, strict=strict, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads |
| 114 | + ) |
107 | 115 |
|
108 | 116 | def save_unsharded_optimizer(
|
109 | 117 | self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
|
@@ -131,13 +139,17 @@ def save_unsharded_optimizer(
|
131 | 139 | else:
|
132 | 140 | save_state_dict(state_dict, checkpoint, use_safetensors=False)
|
133 | 141 |
|
134 |
| - def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str): |
| 142 | + def load_unsharded_optimizer( |
| 143 | + self, optimizer: GeminiOptimizer, checkpoint: str, low_cpu_mem_mode: bool = True, num_threads: int = 1 |
| 144 | + ): |
135 | 145 | """
|
136 | 146 | Loading unsharded optimizer from checkpoint file.
|
137 | 147 | For each process, only loading optimizer states of parameters it controls.
|
138 | 148 | """
|
139 | 149 | assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
|
140 |
| - super().load_unsharded_optimizer(optimizer, checkpoint) |
| 150 | + super().load_unsharded_optimizer( |
| 151 | + optimizer, checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads |
| 152 | + ) |
141 | 153 |
|
142 | 154 | def save_sharded_model(
|
143 | 155 | self,
|
@@ -206,13 +218,27 @@ def save_sharded_model(
|
206 | 218 | )
|
207 | 219 |
|
208 | 220 | def load_sharded_model(
|
209 |
| - self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False |
| 221 | + self, |
| 222 | + model: GeminiDDP, |
| 223 | + checkpoint_index_file: Path, |
| 224 | + strict: bool = False, |
| 225 | + use_safetensors: bool = False, |
| 226 | + low_cpu_mem_mode: bool = True, |
| 227 | + num_threads: int = 1, |
210 | 228 | ):
|
211 | 229 | """
|
212 | 230 | Load shard model, load model from multiple files.
|
213 | 231 | """
|
214 | 232 | assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
|
215 |
| - return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) |
| 233 | + return super().load_sharded_model( |
| 234 | + model, |
| 235 | + checkpoint_index_file, |
| 236 | + strict, |
| 237 | + use_safetensors, |
| 238 | + load_sub_module=False, |
| 239 | + low_cpu_mem_mode=low_cpu_mem_mode, |
| 240 | + num_threads=num_threads, |
| 241 | + ) |
216 | 242 |
|
217 | 243 | def save_sharded_optimizer(
|
218 | 244 | self,
|
@@ -289,7 +315,14 @@ def save_sharded_optimizer(
|
289 | 315 | ranks=[0],
|
290 | 316 | )
|
291 | 317 |
|
292 |
| - def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str): |
| 318 | + def load_sharded_optimizer( |
| 319 | + self, |
| 320 | + optimizer: GeminiOptimizer, |
| 321 | + checkpoint_index_file: Path, |
| 322 | + prefix: str, |
| 323 | + low_cpu_mem_mode: bool = True, |
| 324 | + num_threads: int = 1, |
| 325 | + ): |
293 | 326 | """
|
294 | 327 | Loading sharded optimizer from checkpoint folder, with index file given.
|
295 | 328 | For each process, only loading optimizer states of parameters it controls.
|
@@ -322,9 +355,9 @@ def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_fi
|
322 | 355 | state_dict_shard = load_flat(shard_file)
|
323 | 356 | else:
|
324 | 357 | state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
|
| 358 | + if not low_cpu_mem_mode: |
| 359 | + state_dict_shard = create_pinned_state_dict(state_dict_shard, empty=False, num_threads=num_threads) |
325 | 360 | optimizer.load_param_states(state_dict_shard)
|
326 |
| - del state_dict_shard |
327 |
| - gc.collect() |
328 | 361 |
|
329 | 362 | optimizer.optimizer_loading_epilogue()
|
330 | 363 |
|
|
0 commit comments