Commit 0fffc72
[fix] optimise `_merge_params` to prevent CPU overload (#616)
### Context
While attempting to finetune `pi0_fast_base` on some DROID data\[1],
my swap space kept filling up and crashing the process before training
even began. Stepping through the code, I found the culprit: an
inefficient method, `_merge_params`\[2], that was allocating far more
memory than it needed to.
\[1] The training config I used is:
```python
TrainConfig(
    name="pi0_fast_droid_finetune_low_mem",
    model=pi0_fast.Pi0FASTConfig(
        action_dim=8,
        action_horizon=16,
        max_token_len=180,
        paligemma_variant="gemma_2b_lora",
    ),
    data=RLDSDroidDataConfig(
        repo_id="droid",
        # Set this to the path to your DROID RLDS dataset (the parent directory of the `droid` directory).
        rlds_data_dir="data",
        action_space=droid_rlds_dataset.DroidActionSpace.JOINT_POSITION,
    ),
    weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_fast_base/params"),
    num_train_steps=10,  # 100k steps should be sufficient, takes ~2 days on 8x H100s
    batch_size=1,
    save_interval=10,
    num_workers=0,  # Important: RLDS DataLoader requires num_workers=0, handles multi-processing internally
    freeze_filter=pi0_fast.Pi0FASTConfig(
        action_dim=8, action_horizon=16, max_token_len=180, paligemma_variant="gemma_2b_lora"
    ).get_freeze_filter(),
    # Turn off EMA for LoRA finetuning.
    ema_decay=None,
),
```
\[2] The `_merge_params` method invoked from `CheckpointWeightLoader` is
as below:
```python
1 def _merge_params(loaded_params: at.Params, params: at.Params, *, missing_regex: str) -> at.Params:
2 """Merges the loaded parameters with the reference parameters.
3 Args:
4 loaded_params: The parameters to merge.
5 params: The reference parameters.
6 missing_regex: A regex pattern for all missing keys that should be merged from the reference parameters.
7 Returns:
8 A new dictionary with the merged parameters.
9 """
10 flat_ref = flax.traverse_util.flatten_dict(params, sep="/")
11 flat_loaded = flax.traverse_util.flatten_dict(loaded_params, sep="/")
12 # First, take all weights that are a subset of the reference weights.
13 result = {}
14 for k, v in flat_loaded.items():
15 if k in flat_ref:
16 result[k] = v.astype(flat_ref[k].dtype)
17 # Then, merge any missing weights as defined by the missing regex.
18 pattern = re.compile(missing_regex)
19 for k in {k for k in flat_ref if pattern.fullmatch(k)}:
20 if k not in result:
21 result[k] = flat_ref[k]
22 return flax.traverse_util.unflatten_dict(result, sep="/")
```
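For illustration, here is a standalone sketch of the same merge logic, re-implemented with plain dicts and NumPy so it runs without flax or the openpi codebase. The `flatten` helper below is a minimal stand-in for `flax.traverse_util.flatten_dict(..., sep="/")`; names and shapes are made up for the example.

```python
import re

import numpy as np


def flatten(d, prefix=""):
    # Minimal stand-in for flax.traverse_util.flatten_dict(..., sep="/").
    out = {}
    for k, v in d.items():
        key = f"{prefix}{k}"
        if isinstance(v, dict):
            out.update(flatten(v, key + "/"))
        else:
            out[key] = v
    return out


def merge_params(loaded, ref, *, missing_regex):
    flat_ref = flatten(ref)
    flat_loaded = flatten(loaded)
    # Take all loaded weights that exist in the reference, cast to the
    # reference dtype (this mirrors the pre-fix behaviour on line 16).
    result = {k: v.astype(flat_ref[k].dtype) for k, v in flat_loaded.items() if k in flat_ref}
    # Then fill in any missing weights matched by the regex from the reference.
    pattern = re.compile(missing_regex)
    for k in flat_ref:
        if pattern.fullmatch(k) and k not in result:
            result[k] = flat_ref[k]
    return result


ref = {"encoder": {"w": np.zeros(2, np.float32)}, "lora": {"a": np.ones(2, np.float32)}}
loaded = {"encoder": {"w": np.ones(2, np.float64)}}
merged = merge_params(loaded, ref, missing_regex="lora/.*")
print(sorted(merged))             # ['encoder/w', 'lora/a']
print(merged["encoder/w"].dtype)  # float32 — cast to the reference dtype
```

The checkpoint's `encoder/w` is kept (cast to the reference dtype), while the LoRA weight missing from the checkpoint is pulled in from the freshly initialised reference parameters.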
### The Fix
\[3] Initial memory profile · \[4] profile after removing the redundant `astype` copies · \[5] profile after additionally freeing `flat_loaded` early. (The three memory graphs are omitted here.)
#### Reducing Redundant Allocations
In the initial memory profile (\[3]), we see a number of jumps in
allocated memory reported by `tracemalloc`. The likely cause is line
16, which always creates a copy of the value `v` in order to cast it
to the dtype of `flat_ref[k]`. This is unnecessary when `v` already
has that dtype: in that case no new copy is needed at all.
Thus, I replaced line 16 with the following:
```python
result[k] = v.astype(flat_ref[k].dtype) if v.dtype != flat_ref[k].dtype else v
```
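The guard matters because NumPy's `ndarray.astype` allocates a fresh array by default even when the target dtype equals the current one. A small demonstration (toy arrays, not the actual checkpoint data):

```python
import numpy as np

v = np.ones((2, 2), dtype=np.float32)

# astype copies by default, even for a same-dtype "cast".
copy = v.astype(np.float32)
print(copy is v)  # False — a fresh allocation was made

# The guarded form skips the cast entirely when dtypes already match.
same = v.astype(np.float32) if v.dtype != np.float32 else v
print(same is v)  # True — no new allocation
```

For multi-gigabyte parameter trees where most leaves already have the right dtype, skipping those copies removes most of the transient allocations.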
As the second memory graph (\[4]) shows, this eliminated the numerous
allocations reported by `tracemalloc` and reduced the total memory
allocated over the course of the method's runtime.
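The difference can be reproduced with a small, hypothetical `tracemalloc` micro-benchmark (not from the commit; array sizes are arbitrary). NumPy array allocations have been tracked by `tracemalloc` since NumPy 1.13, which is why the copies showed up in the profile:

```python
import tracemalloc

import numpy as np


def cast_always(vals, dtype):
    # Mirrors the original line 16: astype always allocates a fresh array.
    return [v.astype(dtype) for v in vals]


def cast_if_needed(vals, dtype):
    # Mirrors the fix: reuse the array when the dtype already matches.
    return [v.astype(dtype) if v.dtype != dtype else v for v in vals]


vals = [np.ones(1_000_000, dtype=np.float32) for _ in range(4)]  # ~16 MB total

peaks = {}
for fn in (cast_always, cast_if_needed):
    tracemalloc.start()
    fn(vals, np.float32)
    _, peaks[fn.__name__] = tracemalloc.get_traced_memory()
    tracemalloc.stop()

print(peaks)  # cast_always peaks roughly 16 MB higher than cast_if_needed
```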
#### Reducing RSS Memory
Secondly, we reduce the peak memory occupied by the process by freeing
`flat_loaded` as soon as it is no longer required. The results can be
seen in the third memory graph (\[5]): peak RSS drops to about 22 GB,
compared with about 26 GB for the previous variants.
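The exact diff is not reproduced here, but one way to realise this early-free pattern is to drain the source dict while building the result, so each entry's reference is dropped as soon as it has been consumed rather than keeping two full views of the checkpoint alive until the function returns. A hypothetical standalone sketch:

```python
import numpy as np


def merge(flat_loaded, flat_ref):
    # Hypothetical early-free variant: popitem removes each entry from
    # flat_loaded as we consume it, so the dict no longer pins the tensors.
    result = {}
    while flat_loaded:
        k, v = flat_loaded.popitem()
        if k in flat_ref:
            result[k] = v.astype(flat_ref[k].dtype) if v.dtype != flat_ref[k].dtype else v
    return result


ref = {"w": np.zeros(3, np.float32)}
loaded = {"w": np.ones(3, np.float32), "extra": np.ones(3, np.float32)}
merged = merge(loaded, ref)
print(sorted(merged), len(loaded))  # ['w'] 0 — the source dict is drained
```

Note that in the real method the leaf arrays are also referenced by the caller's `loaded_params` tree, so the actual RSS savings depend on that reference being released too; the sketch only illustrates the dict-draining idea.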
After the above two changes, I was able to successfully finetune the
base model [on
`droid_100`](https://huggingface.co/datasets/lerobot/droid_100).