
Commit 0fffc72

[fix] optimise _merge_params to prevent CPU overload (#616)
### Context

On attempting to finetune `pi0_fast_base` on some DROID data \[1], my swap space kept filling up and crashing the process before training had even begun. Stepping through the code, I found the culprit: an inefficient method, `_merge_params` \[2], that was allocating far more memory than it needed to.

\[1] The training config I used:

```python
TrainConfig(
    name="pi0_fast_droid_finetune_low_mem",
    model=pi0_fast.Pi0FASTConfig(
        action_dim=8,
        action_horizon=16,
        max_token_len=180,
        paligemma_variant="gemma_2b_lora",
    ),
    data=RLDSDroidDataConfig(
        repo_id="droid",
        # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory).
        rlds_data_dir="data",
        action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
    num_train_steps=10,  # 100k steps should be sufficient, takes ~2 days on 8x H100s
    batch_size=1,
    save_interval=10,
    num_workers=0,  # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally
    freeze_filter=pi0_fast.Pi0FASTConfig(
        action_dim=8, action_horizon=16, max_token_len=180, paligemma_variant="gemma_2b_lora"
    ).get_freeze_filter(),
    # Turn off EMA for LoRA finetuning.
    ema_decay=None,
)
```

\[2] The `_merge_params` method invoked from `CheckpointWeightLoader`:

```python
def merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params:
    """Merges the loaded parameters with the reference parameters.

    Args:
        loaded_params: The parameters to merge.
        params: The reference parameters.
        missing_regex: A regex pattern for all missing keys that should be merged from the reference parameters.

    Returns:
        A new dictionary with the merged parameters.
    """
    flat_ref = flax.traverse_util.flatten_dict(params, sep="/")
    flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep="/")

    # First, take all weights that are a subset of the reference weights.
    result = {}
    for k, v in flat_loaded.items():
        if k in flat_ref:
            result[k] = v.astype(flat_ref[k].dtype)

    # Then, merge any missing weights as defined by the missing regex.
    pattern = re.compile(missing_regex)
    for k in {k for k in flat_ref if pattern.fullmatch(k)}:
        if k not in result:
            result[k] = flat_ref[k]

    return flax.traverse_util.unflatten_dict(result, sep="/")
```

### The Fix

| **3** | **4** | **5** |
|:-----:|:-----:|:-----:|
| ![original](https://github.com/user-attachments/assets/a3da3aca-2591-4afb-9ae1-73a529711fb7) | ![remove_redundant](https://github.com/user-attachments/assets/6e2bd17e-963c-4521-a117-8e9bd1489347) | ![free_traverse_utils](https://github.com/user-attachments/assets/4b58e189-d3a7-41d4-9813-bc19fccf31b6) |

#### Reducing Redundant Allocations

The initial memory profile (\[3]) shows a number of jumps in allocated memory reported by `tracemalloc`. These are most likely caused by the line `result[k] = v.astype(flat_ref[k].dtype)`, which always creates a copy of `v` in order to cast it to the dtype of `flat_ref[k]`. This is unnecessary: if `v` already has the same dtype as `flat_ref[k]`, no new copy is needed at all. I therefore replaced that line with:

```python
result[k] = v.astype(flat_ref[k].dtype) if v.dtype != flat_ref[k].dtype else v
```

As the second memory graph (\[4]) shows, this eliminates the numerous allocations reported by `tracemalloc` and reduces the total memory allocated over the course of the method's runtime.

#### Reducing RSS Memory

Secondly, we reduce the peak memory occupied by the process by freeing `flat_loaded` the moment it is no longer required.
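Taken together, the two changes amount to the pattern below. This is a minimal, flax-free sketch under stated assumptions: plain flat dicts of NumPy arrays stand in for the real flattened pytrees, and `merge_params_sketch` is an illustrative name, not the library function.

```python
import re

import numpy as np


def merge_params_sketch(flat_loaded: dict, flat_ref: dict, missing_regex: str) -> dict:
    """Illustrative merge: skip no-op dtype casts and release flat_loaded early."""
    result = {}
    for k, v in flat_loaded.items():
        if k in flat_ref:
            # Only allocate a cast copy when the dtype actually differs.
            result[k] = v.astype(flat_ref[k].dtype) if v.dtype != flat_ref[k].dtype else v
    # Drop the references held by flat_loaded as soon as the merge is done,
    # so any arrays not kept in `result` can be garbage-collected.
    flat_loaded.clear()
    # Fill in missing keys matched by the regex from the reference params.
    pattern = re.compile(missing_regex)
    for k in flat_ref:
        if pattern.fullmatch(k) and k not in result:
            result[k] = flat_ref[k]
    return result


loaded = {"encoder/w": np.ones(4, dtype=np.float32)}
ref = {
    "encoder/w": np.zeros(4, dtype=np.float32),
    "lora/a": np.zeros(2, dtype=np.float32),
}
merged = merge_params_sketch(loaded, ref, missing_regex="lora/.*")
assert sorted(merged) == ["encoder/w", "lora/a"]
assert loaded == {}  # references released after the merge
```

Because `encoder/w` already has the reference dtype, the merged dict reuses the loaded array instead of allocating a cast copy.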
The results can be seen in the third memory graph (\[5]), which shows a peak RSS of about 22 GB, compared with roughly 26 GB for the previous variants.

After these two changes, I was able to successfully finetune the base model [on `droid_100`](https://huggingface.co/datasets/lerobot/droid_100).
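One detail worth noting: NumPy's `astype` returns a freshly allocated array by default even when the requested dtype already matches, which is why the guard helps; `astype(dtype, copy=False)` achieves the same reuse. A quick sanity check:

```python
import numpy as np

v = np.zeros(1_000_000, dtype=np.float32)

# Unconditional astype always copies, even for a matching dtype.
copied = v.astype(np.float32)
assert not np.shares_memory(v, copied)

# The guarded version returns the original array untouched.
guarded = v.astype(np.float32) if v.dtype != np.float32 else v
assert guarded is v

# copy=False also skips the allocation when no conversion is needed.
no_copy = v.astype(np.float32, copy=False)
assert np.shares_memory(v, no_copy)
```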
2 parents 35e10d2 + 237d886 commit 0fffc72

File tree

1 file changed (+3, −1 lines)


src/openpi/training/weight_loaders.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -91,7 +91,9 @@ def _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex:
     result = {}
     for k, v in flat_loaded.items():
         if k in flat_ref:
-            result[k] = v.astype(flat_ref[k].dtype)
+            result[k] = v.astype(flat_ref[k].dtype) if v.dtype != flat_ref[k].dtype else v
+
+    flat_loaded.clear()
 
     # Then, merge any missing weights as defined by the missing regex.
     pattern = re.compile(missing_regex)
```
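The allocation difference behind the memory graphs can be reproduced with `tracemalloc` (NumPy reports its buffer allocations to tracemalloc, so array copies show up in the traced peak). A small sketch with illustrative sizes:

```python
import tracemalloc

import numpy as np

# Stand-in for a batch of loaded parameters, ~16 MB total.
params = [np.zeros(1_000_000, dtype=np.float32) for _ in range(4)]

tracemalloc.start()

# Old behaviour: unconditional astype duplicates every matching-dtype array.
copies = [v.astype(np.float32) for v in params]
_, peak_old = tracemalloc.get_traced_memory()

del copies
tracemalloc.reset_peak()

# New behaviour: the dtype guard reuses the existing arrays.
merged = [v.astype(np.float32) if v.dtype != np.float32 else v for v in params]
_, peak_new = tracemalloc.get_traced_memory()
tracemalloc.stop()

assert peak_new < peak_old  # the guarded version allocates essentially nothing
```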
