Commit daad8f9

Standardize out_dir behavior (#1338)
1 parent bd8cec9

8 files changed: +38 −23 lines

litgpt/finetune/adapter.py

Lines changed: 4 additions & 1 deletion
@@ -29,6 +29,7 @@
     chunked_cross_entropy,
     copy_config_files,
     get_default_supported_precision,
+    init_out_dir,
     load_checkpoint,
     num_parameters,
     parse_devices,
@@ -61,7 +62,8 @@ def setup(

     Arguments:
         checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.
-        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
+        out_dir: Directory in which to save checkpoints and logs.
+            /teamspace/jobs/<job-name>/share.
         precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
         quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
         devices: How many devices/GPUs to use.
@@ -75,6 +77,7 @@ def setup(
     pprint(locals())
     data = Alpaca() if data is None else data
     devices = parse_devices(devices)
+    out_dir = init_out_dir(out_dir)

     check_valid_checkpoint_dir(checkpoint_dir)
     config = Config.from_file(checkpoint_dir / "model_config.yaml")
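
The same three-line change is repeated in the other finetuning entry points below: import init_out_dir and normalize out_dir right after parse_devices, before any checkpoints or logs are written. A minimal sketch of the resulting behavior, assuming the Lightning platform exports LIGHTNING_ARTIFACTS_DIR (the variable name comes from the helper itself; the path value below is only an illustrative example):

import os
from pathlib import Path

from litgpt.utils import init_out_dir  # the helper's new home after this commit

# Example value only; in a Lightning Studio Job the platform is expected to set this.
os.environ["LIGHTNING_ARTIFACTS_DIR"] = "/teamspace/jobs/my-job/share"

print(init_out_dir(Path("out/finetune/adapter")))
# /teamspace/jobs/my-job/share/out/finetune/adapter  (relative paths get prefixed)

print(init_out_dir(Path("/data/checkpoints")))
# /data/checkpoints  (absolute paths are returned unchanged)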

litgpt/finetune/adapter_v2.py

Lines changed: 4 additions & 1 deletion
@@ -29,6 +29,7 @@
     chunked_cross_entropy,
     copy_config_files,
     get_default_supported_precision,
+    init_out_dir,
     load_checkpoint,
     num_parameters,
     parse_devices,
@@ -61,7 +62,8 @@ def setup(

     Arguments:
         checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.
-        out_dir: Directory in which to save checkpoints and logs.
+        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
+            /teamspace/jobs/<job-name>/share.
         precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
         quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
         devices: How many devices/GPUs to use.
@@ -75,6 +77,7 @@ def setup(
     pprint(locals())
     data = Alpaca() if data is None else data
     devices = parse_devices(devices)
+    out_dir = init_out_dir(out_dir)

     check_valid_checkpoint_dir(checkpoint_dir)
     config = Config.from_file(checkpoint_dir / "model_config.yaml")

litgpt/finetune/full.py

Lines changed: 4 additions & 1 deletion
@@ -28,6 +28,7 @@
     copy_config_files,
     get_default_supported_precision,
     load_checkpoint,
+    init_out_dir,
     num_parameters,
     parse_devices,
     save_hyperparameters,
@@ -59,7 +60,8 @@ def setup(

     Arguments:
         checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.
-        out_dir: Directory in which to save checkpoints and logs.
+        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
+            /teamspace/jobs/<job-name>/share.
         precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
         devices: How many devices/GPUs to use
         resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
@@ -74,6 +76,7 @@ def setup(
     pprint(locals())
     data = Alpaca() if data is None else data
     devices = parse_devices(devices)
+    out_dir = init_out_dir(out_dir)

     check_valid_checkpoint_dir(checkpoint_dir)
     config = Config.from_file(checkpoint_dir / "model_config.yaml")

litgpt/finetune/lora.py

Lines changed: 4 additions & 1 deletion
@@ -31,6 +31,7 @@
     copy_config_files,
     get_default_supported_precision,
     load_checkpoint,
+    init_out_dir,
     num_parameters,
     parse_devices,
     save_hyperparameters,
@@ -71,7 +72,8 @@ def setup(

     Arguments:
         checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.
-        out_dir: Directory in which to save checkpoints and logs.
+        out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
+            /teamspace/jobs/<job-name>/share.
         precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
         quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
         devices: How many devices/GPUs to use.
@@ -94,6 +96,7 @@ def setup(
     pprint(locals())
     data = Alpaca() if data is None else data
     devices = parse_devices(devices)
+    out_dir = init_out_dir(out_dir)

     check_valid_checkpoint_dir(checkpoint_dir)
     config = Config.from_file(

litgpt/pretrain.py

Lines changed: 1 addition & 7 deletions
@@ -1,7 +1,6 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

 import math
-import os
 import pprint
 import time
 from datetime import timedelta
@@ -30,6 +29,7 @@
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
+    init_out_dir,
     num_parameters,
     parse_devices,
     reset_parameters,
@@ -404,12 +404,6 @@ def init_weights(module, std):
         reset_parameters(model)


-def init_out_dir(out_dir: Path) -> Path:
-    if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ:
-        return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir
-    return out_dir
-
-
 def save_checkpoint(fabric, state, tokenizer_dir, checkpoint_file):
     model = state["model"]
     checkpoint_file.parent.mkdir(parents=True, exist_ok=True)

litgpt/utils.py

Lines changed: 7 additions & 0 deletions
@@ -3,6 +3,7 @@
 """Utility functions for training and inference."""
 import inspect
 import math
+import os
 import pickle
 import shutil
 import sys
@@ -27,6 +28,12 @@
     from litgpt import GPT, Config


+def init_out_dir(out_dir: Path) -> Path:
+    if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ:
+        return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir
+    return out_dir
+
+
 def find_multiple(n: int, k: int) -> int:
     assert k > 0
     if n % k == 0:

tests/test_pretrain.py

Lines changed: 2 additions & 12 deletions
@@ -13,10 +13,11 @@
 from lightning.fabric.strategies import FSDPStrategy, SingleDeviceStrategy
 from torch.utils.data import DataLoader

+from test_utils import test_init_out_dir
 from litgpt import pretrain
 from litgpt.args import EvalArgs, TrainArgs
 from litgpt.config import Config
-from litgpt.pretrain import init_out_dir, initialize_weights
+from litgpt.pretrain import initialize_weights


 @RunIf(min_cuda_gpus=2, standalone=True)
@@ -89,17 +90,6 @@ def test_pretrain_model_name_and_config():
     pretrain.setup(model_name="tiny-llama-1.1b", model_config=Config(name="tiny-llama-1.1b"))


-def test_init_out_dir(tmp_path):
-    relative_path = Path("./out")
-    absolute_path = tmp_path / "out"
-    assert init_out_dir(relative_path) == relative_path
-    assert init_out_dir(absolute_path) == absolute_path
-
-    with mock.patch.dict(os.environ, {"LIGHTNING_ARTIFACTS_DIR": "prefix"}):
-        assert init_out_dir(relative_path) == Path("prefix") / relative_path
-        assert init_out_dir(absolute_path) == absolute_path
-
-
 @pytest.mark.parametrize(("strategy", "expected"), [(SingleDeviceStrategy, True), (FSDPStrategy, False)])
 def test_initialize_weights(strategy, expected):
     fabric_mock = Mock()

tests/test_utils.py

Lines changed: 12 additions & 0 deletions
@@ -30,6 +30,7 @@
     copy_config_files,
     find_multiple,
     incremental_save,
+    init_out_dir,
     num_parameters,
     parse_devices,
     save_hyperparameters,
@@ -294,3 +295,14 @@ def test_choose_logger(tmp_path):

     with pytest.raises(ValueError, match="`--logger_name=foo` is not a valid option."):
         choose_logger("foo", out_dir=tmp_path, name="foo")
+
+
+def test_init_out_dir(tmp_path):
+    relative_path = Path("./out")
+    absolute_path = tmp_path / "out"
+    assert init_out_dir(relative_path) == relative_path
+    assert init_out_dir(absolute_path) == absolute_path
+
+    with mock.patch.dict(os.environ, {"LIGHTNING_ARTIFACTS_DIR": "prefix"}):
+        assert init_out_dir(relative_path) == Path("prefix") / relative_path
+        assert init_out_dir(absolute_path) == absolute_path

0 commit comments