
Commit 0cafdb1

manual gc in llama3

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

1 parent: e0b8624

File tree: 3 files changed, +15 −4 lines

- bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
- bionemo-recipes/recipes/llama3_native_te/perf_logger.py
- bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py


bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml
Lines changed: 1 addition & 0 deletions

@@ -74,6 +74,7 @@ logger:
 
 profiler:
   enabled: false
+  gc_interval: 1_000 # Run garbage collection every 1000 steps
   schedule:
     wait: 10
     warmup: 10
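Since this recipe loads defaults.yaml through Hydra, the new key can also be overridden per run from the command line rather than by editing the file. A hypothetical invocation (the entrypoint name is taken from the train_fsdp2.py file touched by this commit, and the value 500 is only an example; standard Hydra override syntax is assumed):

python train_fsdp2.py profiler.gc_interval=500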

bionemo-recipes/recipes/llama3_native_te/perf_logger.py
Lines changed: 14 additions & 0 deletions

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import logging
 import time
 from pathlib import Path
@@ -71,6 +72,11 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
         self.previous_step_time = time.perf_counter()
         self._profiler = None
 
+        # Manually control garbage collection for cleaner profiling.
+        self._gc_interval = args.profiler.gc_interval
+        gc.disable()
+        self._run_garbage_collection()
+
         if self._dist_config.is_main_process():
             # Log the entire args object to wandb for experiment tracking and reproducibility.
             self._wandb_run = wandb.init(**args.wandb, config=self._run_config)
@@ -134,6 +140,9 @@ def log_step(
         if self._dist_config.local_rank == 0:
             logger.info(", ".join([f"{k.split('/')[1]}: {v:.3g}" for k, v in metrics.items()]))
 
+        if (step + 1) % self._gc_interval == 0:
+            self._run_garbage_collection()
+
     def finish(self):
         """Finish the logger and close the progress bar."""
         if self._profiler is not None:
@@ -145,6 +154,11 @@ def finish(self):
             wandb.finish()
         self._progress_bar.close()
 
+    def _run_garbage_collection(self):
+        """Run garbage collection."""
+        gc.collect()
+        torch.cuda.empty_cache()
+
 
 def setup_profiler(args: DictConfig, wandb_run: wandb.Run):
     """Setup a basic torch profiler for the experiment.

bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
Lines changed: 0 additions & 4 deletions

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import gc
 import logging
 from contextlib import nullcontext
 from pathlib import Path
@@ -128,9 +127,6 @@ def main(args: DictConfig) -> float | None:
 
     perf_logger = PerfLogger(dist_config, args)
 
-    gc.collect()
-    torch.cuda.empty_cache()
-
     # Training loop
     logger.info(f"Starting training loop from step {start_step} to {args.num_train_steps}")
     step = start_step
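Note that the collect-on-startup behavior removed from main() is not lost: constructing PerfLogger now disables automatic GC and performs the initial collection itself, so train_fsdp2.py no longer needs to import gc or call torch.cuda.empty_cache() directly.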
