
Commit b9ddd8b

rasbt and carmocca authored
Add precision arg for pretraining (#1353)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 43c4432 commit b9ddd8b

4 files changed (+16, -2 lines)
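Taken together, the change threads a user-settable precision value from the pretraining configs through setup() into L.Fabric, replacing the previously hard-coded "bf16-mixed". A minimal, hedged sketch of calling the updated entry point directly; the model name below is illustrative and not part of this commit, and all other arguments are left at their defaults:

from pathlib import Path

from litgpt.pretrain import setup  # the function whose signature this commit extends

setup(
    model_name="tiny-llama-1.1b",        # illustrative; any name known to litgpt's name_to_config works
    out_dir=Path("out/pretrain/debug"),
    precision="bf16-mixed",              # the new argument; None auto-selects a supported default
)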

config_hub/pretrain/debug.yaml

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@ model_config:
 # /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
 out_dir: out/pretrain/debug
 
+# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
+precision: bf16-mixed
+
 # Optional path to a checkpoint directory to initialize the model from.
 # Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
 initial_checkpoint_dir:
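The comment above documents a null default for the new key. When it is left unset, the change in litgpt/pretrain.py (shown further below) falls back to get_default_supported_precision(training=True), which is imported at the top of that file. A rough, hypothetical sketch of such a fallback, limited to the choices listed in the YAML comment and not the actual litgpt implementation:

import torch

def get_default_supported_precision(training: bool) -> str:
    # Hypothetical simplification: prefer bf16 when the GPU supports it, else plain fp32.
    if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
        return "32-true"
    return "bf16-mixed" if training else "bf16-true"

precision = None  # what setup() receives when the YAML key stays null
precision = precision or get_default_supported_precision(training=True)
print(precision)  # "bf16-mixed" on bf16-capable hardware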

config_hub/pretrain/tinyllama.yaml

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,9 @@ model_config:
 # /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
 out_dir: out/pretrain/tiny-llama
 
+# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
+precision: bf16-mixed
+
 # Optional path to a checkpoint directory to initialize the model from.
 # Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
 initial_checkpoint_dir:

config_hub/pretrain/tinystories.yaml

Lines changed: 3 additions & 0 deletions
@@ -27,6 +27,9 @@ model_config:
 # /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
 out_dir: out/pretrain/stories15M
 
+# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
+precision: bf16-mixed
+
 # Optional path to a checkpoint directory to initialize the model from.
 # Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
 initial_checkpoint_dir:

litgpt/pretrain.py

Lines changed: 7 additions & 2 deletions
@@ -29,6 +29,7 @@
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
+    get_default_supported_precision,
     init_out_dir,
     num_parameters,
     parse_devices,
@@ -42,6 +43,7 @@ def setup(
     model_name: Optional[str] = None,
     model_config: Optional[Config] = None,
     out_dir: Path = Path("out/pretrain"),
+    precision: Literal["bf16-true", "bf16-mixed", "32-true", None] = None,
     initial_checkpoint_dir: Optional[Path] = None,
     resume: Union[bool, Path] = False,
     data: Optional[DataModule] = None,
@@ -75,6 +77,7 @@ def setup(
             ``model_config``.
         out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
             /teamspace/jobs/<job-name>/share.
+        precision: The precision to use for finetuning. Determines a compatible precision setting by default.
         initial_checkpoint_dir: Optional path to a checkpoint directory to initialize the model from.
             Useful for continued pretraining. Mutually exclusive with ``resume``.
         resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
@@ -96,6 +99,7 @@ def setup(
         available_models = "\n".join(sorted(name_to_config))
         raise ValueError(f"Please specify --model_name <model_name>. Available values:\n{available_models}")
     config = Config.from_name(model_name) if model_config is None else model_config
+    precision = precision or get_default_supported_precision(training=True)
     devices = parse_devices(devices)
     out_dir = init_out_dir(out_dir)
     # in case the dataset requires the Tokenizer
@@ -109,7 +113,7 @@ def setup(
         strategy = FSDPStrategy(auto_wrap_policy={Block}, state_dict_type="full", sharding_strategy="HYBRID_SHARD")
     else:
         strategy = "auto"
-    fabric = L.Fabric(devices=devices, strategy=strategy, precision="bf16-mixed", loggers=[logger])
+    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger])
     fabric.launch()
 
     fabric.print(pprint.pformat(hparams))
@@ -169,12 +173,13 @@ def main(
 
     model = torch.compile(model)
     model = fabric.setup(model)
+
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=train.learning_rate,
         weight_decay=train.weight_decay,
         betas=(train.beta1, train.beta2),
-        fused=True,
+        fused=fabric.device.type == "cuda",
     )
     optimizer = fabric.setup_optimizers(optimizer)
 
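The last hunk above also stops passing fused=True unconditionally to AdamW. On older PyTorch builds the fused implementation is CUDA-only, so the flag would raise for CPU runs such as the debug config; gating it on fabric.device.type keeps the fast path on GPU while staying safe elsewhere. A small standalone sketch of the same guard, with a stand-in model and an illustrative learning rate:

import torch

model = torch.nn.Linear(8, 8)                        # stand-in model; its parameters live on CPU here
device_type = next(model.parameters()).device.type   # "cpu" or "cuda", analogous to fabric.device.type

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,                       # illustrative; litgpt takes this from train.learning_rate
    fused=device_type == "cuda",   # mirrors the diff: only request the fused kernel on CUDA
)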