|
7 | 7 | import logging |
8 | 8 | import math |
9 | 9 | import os |
| 10 | +import shutil |
10 | 11 | import time |
11 | 12 | from collections.abc import Mapping |
12 | 13 | from dataclasses import dataclass, field, fields |
|
39 | 40 | from forge.data.utils import batch_to_device |
40 | 41 |
|
41 | 42 | logger = logging.getLogger(__name__) |
42 | | -logger.setLevel(logging.INFO) |
| 43 | +logger.setLevel(logging.DEBUG) |
| 44 | + |
| 45 | + |
def cleanup_old_weight_versions(
    state_dict_key: str,
    delim: str,
    current_policy_version: int,
) -> None:
    """Delete old weight versions, keeping only current and N-1 versions.

    TODO - issues/194: provide a more robust way to handle eviction.

    Args:
        state_dict_key: The base key for state dict storage
        delim: The delimiter used between key and version
        current_policy_version: The current policy version to keep
    """
    if current_policy_version <= 1:
        return  # No cleanup needed for versions 0 or 1

    prefix = f"{state_dict_key}{delim}"
    # Only the current and immediately-previous versions survive eviction.
    keep = {
        os.path.basename(f"{prefix}{current_policy_version}"),
        os.path.basename(f"{prefix}{current_policy_version - 1}"),
    }

    # Find all weight directories that match our pattern
    parent_dir = os.path.dirname(prefix) or "."
    if not os.path.exists(parent_dir):
        return

    base_prefix = os.path.basename(prefix)
    for item in os.listdir(parent_dir):
        item_path = os.path.join(parent_dir, item)
        if (
            item.startswith(base_prefix)
            and item not in keep
            and os.path.isdir(item_path)
        ):
            try:
                # BUGFIX: ignore_errors=True suppressed every OSError before
                # the except clause could see it, so failed deletions were
                # logged as successful removals. Let rmtree raise instead.
                shutil.rmtree(item_path)
                logger.debug(f"Removed old weights at {item_path}")
            except OSError as e:
                # Best-effort eviction: log and continue with remaining dirs.
                logger.debug(f"Error deleting {item_path}: {e}")
43 | 83 |
|
44 | 84 |
|
45 | 85 | @dataclass |
@@ -67,6 +107,7 @@ def __post_init__(self): |
67 | 107 | in monarch for now. |
68 | 108 |
|
69 | 109 | """ |
| 110 | + super().__init__() |
70 | 111 | # Instantiate dict fields |
71 | 112 | for f in fields(self): |
72 | 113 | attr = getattr(self, f.name) |
@@ -223,13 +264,26 @@ async def push_weights(self, policy_version: int) -> None: |
223 | 264 | ) |
224 | 265 | hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict) |
225 | 266 | # TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed |
226 | | - vllm_ready_hf_sd = _qwen3_hf_to_vllm(sd=hf_state_dict, num_layers=28) |
| 267 | + vllm_ready_hf_sd = _qwen3_hf_to_vllm( |
| 268 | + sd=hf_state_dict, num_layers=self.engine.model_args.n_layers |
| 269 | + ) |
227 | 270 |
|
228 | 271 | key = f"{self.state_dict_key}{DELIM}{policy_version}" |
229 | 272 | start_time = time.time() |
230 | 273 | if self.use_dcp: |
| 274 | + |
| 275 | + # TODO - DCP should probably be being saved to NFS explicitly? |
| 276 | + # Right now it will only save everything locally |
231 | 277 | metadata = dcp.save(checkpoint_id=key, state_dict=vllm_ready_hf_sd) |
232 | 278 | await ts.put(key, metadata) |
| 279 | + |
| 280 | + # Delete old weight versions if they exist |
| 281 | + if self.rank == 0: |
| 282 | + cleanup_old_weight_versions( |
| 283 | + state_dict_key=self.state_dict_key, |
| 284 | + delim=DELIM, |
| 285 | + current_policy_version=policy_version, |
| 286 | + ) |
233 | 287 | else: |
234 | 288 | await ts.put_state_dict(vllm_ready_hf_sd, key) |
235 | 289 | end_time = time.time() |
|
0 commit comments