@@ -29,6 +29,7 @@
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
+    get_default_supported_precision,
     init_out_dir,
     num_parameters,
     parse_devices,
@@ -42,6 +43,7 @@ def setup(
     model_name: Optional[str] = None,
     model_config: Optional[Config] = None,
     out_dir: Path = Path("out/pretrain"),
+    precision: Literal["bf16-true", "bf16-mixed", "32-true", None] = None,
     initial_checkpoint_dir: Optional[Path] = None,
     resume: Union[bool, Path] = False,
     data: Optional[DataModule] = None,
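The new keyword lets a caller pin the numerics explicitly, while leaving it as None defers to an auto-detected default later in setup(). A hypothetical invocation (the import path and model name are assumptions for illustration, not part of the diff):

    # Hypothetical usage sketch: force full-bf16 pretraining instead of the
    # auto-detected default; "bf16-mixed" or "32-true" would work the same way.
    from litgpt.pretrain import setup

    setup(model_name="pythia-160m", precision="bf16-true")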
@@ -75,6 +77,7 @@ def setup(
             ``model_config``.
         out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
             /teamspace/jobs/<job-name>/share.
+        precision: The precision to use for pretraining. Determines a compatible precision setting by default.
         initial_checkpoint_dir: Optional path to a checkpoint directory to initialize the model from.
             Useful for continued pretraining. Mutually exclusive with ``resume``.
         resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
@@ -96,6 +99,7 @@ def setup(
         available_models = "\n".join(sorted(name_to_config))
         raise ValueError(f"Please specify --model_name <model_name>. Available values:\n{available_models}")
     config = Config.from_name(model_name) if model_config is None else model_config
+    precision = precision or get_default_supported_precision(training=True)
     devices = parse_devices(devices)
     out_dir = init_out_dir(out_dir)
     # in case the dataset requires the Tokenizer
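When the argument is left unset, the default now comes from get_default_supported_precision(training=True). A simplified sketch of what such a helper typically does in this codebase (the real implementation may also special-case MPS and other accelerators):

    import torch

    def default_precision_sketch(training: bool) -> str:
        # Rough approximation: prefer bfloat16 where the GPU supports it,
        # otherwise fall back to float16; training uses the *-mixed variants.
        if torch.cuda.is_available() and not torch.cuda.is_bf16_supported():
            return "16-mixed" if training else "16-true"
        return "bf16-mixed" if training else "bf16-true"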
@@ -109,7 +113,7 @@ def setup(
         strategy = FSDPStrategy(auto_wrap_policy={Block}, state_dict_type="full", sharding_strategy="HYBRID_SHARD")
     else:
         strategy = "auto"
-    fabric = L.Fabric(devices=devices, strategy=strategy, precision="bf16-mixed", loggers=[logger])
+    fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=[logger])
     fabric.launch()
 
     fabric.print(pprint.pformat(hparams))
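Forwarding the resolved value to L.Fabric replaces the previous hard-coded "bf16-mixed". In Fabric, "bf16-mixed" keeps parameters in float32 and autocasts compute to bfloat16, whereas "bf16-true" holds the parameters themselves in bfloat16 and "32-true" disables reduced precision entirely. A minimal sketch of how the resolved setting reaches Fabric (values shown are examples):

    import lightning as L

    # Sketch: whatever setup() resolved (user override or detected default)
    # is passed straight through instead of the old hard-coded "bf16-mixed".
    precision = "bf16-true"  # e.g. an explicit user override
    fabric = L.Fabric(devices=1, strategy="auto", precision=precision)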
@@ -169,12 +173,13 @@ def main(
 
     model = torch.compile(model)
     model = fabric.setup(model)
+
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=train.learning_rate,
         weight_decay=train.weight_decay,
         betas=(train.beta1, train.beta2),
-        fused=True,
+        fused=fabric.device.type == "cuda",
     )
     optimizer = fabric.setup_optimizers(optimizer)
 
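Replacing fused=True with a device check matters because, at the time of this change, torch.optim.AdamW only provided the fused kernel for CUDA tensors, so hard-coding it breaks CPU (and MPS) runs. A standalone illustration of the same guard (the tiny model is made up for the example):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.nn.Linear(8, 8).to(device)

    # Same guard as in the diff: only request the fused kernel where it exists.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=1e-3,
        fused=(device.type == "cuda"),
    )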