6565
6666@run .cli .entrypoint (namespace = "llm" )
6767def train (
68- model : pl .LightningModule ,
68+ model : Union [ pl .LightningModule , AnyPath ] ,
6969 data : pl .LightningDataModule ,
7070 trainer : Trainer ,
7171 log : Annotated [Optional [NeMoLogger ], run .Config [NeMoLogger ]] = None ,
@@ -79,7 +79,7 @@ def train(
7979 Trains a model using the specified data and trainer, with optional tokenizer, source, and export.
8080
8181 Args:
82- model (pl.LightningModule): The model to be trained.
82+ model (Union[ pl.LightningModule, AnyPath] ): The model to be trained or a path to the NeMo 2 checkpoint .
8383 data (pl.LightningDataModule): The data module containing training data.
8484 trainer (Trainer): The trainer instance configured with a MegatronStrategy.
8585 log (NeMoLogger): A nemologger instance.
@@ -106,6 +106,8 @@ def train(
106106 >>> llm.train(model, data, trainer, tokenizer="data")
107107 PosixPath('/path/to/log_dir')
108108 """
109+ model = _load_model_from_path (model )
110+
109111 # [ModelOpt]: If modelopt_state exists, overwrite transformer_layer_spec to modelopt spec
110112 if resume :
111113 if resume .restore_config and resume .restore_config .path :
@@ -131,7 +133,7 @@ def train(
131133
132134@run .cli .entrypoint (namespace = "llm" )
133135def pretrain (
134- model : pl .LightningModule ,
136+ model : Union [ pl .LightningModule , AnyPath ] ,
135137 data : pl .LightningDataModule ,
136138 trainer : Trainer ,
137139 log : Annotated [Optional [NeMoLogger ], run .Config [NeMoLogger ]] = None ,
@@ -145,7 +147,7 @@ def pretrain(
145147 Note, by default it will use the tokenizer from the model.
146148
147149 Args:
148- model (pl.LightningModule): The model to be pretrained.
150+ model (Union[ pl.LightningModule, AnyPath] ): The model to be pretrained or a path to the NeMo 2 checkpoint .
149151 data (pl.LightningDataModule): The data module containing pretraining data.
150152 trainer (Trainer): The trainer instance configured with a MegatronStrategy.
151153 log (NeMoLogger): A nemologger instance.
@@ -166,6 +168,7 @@ def pretrain(
166168 >>> llm.pretrain(model, data, trainer)
167169 PosixPath('/path/to/log_dir')
168170 """
171+ model = _load_model_from_path (model )
169172 _validate_config (model , data , trainer , log = log , resume = resume , optim = optim )
170173
171174 return train (
@@ -181,28 +184,33 @@ def pretrain(
181184
182185@run .cli .entrypoint (namespace = "llm" )
183186def finetune (
184- model : pl .LightningModule ,
187+ model : Union [ pl .LightningModule , AnyPath ] ,
185188 data : pl .LightningDataModule ,
186189 trainer : Trainer ,
187190 log : Annotated [Optional [NeMoLogger ], run .Config [NeMoLogger ]] = None ,
188191 resume : Annotated [Optional [AutoResume ], run .Config [AutoResume ]] = None ,
189192 optim : Optional [OptimizerModule ] = None ,
190193 peft : Optional [Union [PEFT , ModelTransform , Callable ]] = None ,
194+ tokenizer : Optional [TokenizerType ] = "model" ,
191195) -> Path :
192196 """
193197 Finetunes a model using the specified data and trainer, with optional logging, resuming, and PEFT.
194198
195199 Note, by default it will use the tokenizer from the model.
196200
197201 Args:
198- model (pl.LightningModule): The model to be finetuned.
202+ model (Union[ pl.LightningModule, AnyPath] ): The model to be finetuned.
199203 data (pl.LightningDataModule): The data module containing finetuning data.
200204 trainer (Trainer): The trainer instance configured with a MegatronStrategy.
201205 log (NeMoLogger): A nemologger instance.
202206 resume (Optional[AutoResume]): Resume training from a checkpoint.
203207 optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default
204208 optimizer from the model will be used.
205209 peft (Optional[PEFT]): A PEFT (Parameter-Efficient Fine-Tuning) configuration to be applied.
210+ tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model'
211+ or an instance of TokenizerSpec. If 'data' uses the data loader's tokenizer instead of the tokenizer
212+ from the model checkpoint, which is useful for expanding vocabulary or adding special tokens
213+ (such as chat template tokens).
206214
207215 Returns:
208216 Path: The directory path where finetuning artifacts are saved.
@@ -217,7 +225,7 @@ def finetune(
217225 >>> llm.finetune(model, data, trainer, peft=llm.peft.LoRA()])
218226 PosixPath('/path/to/log_dir')
219227 """
220-
228+ model = _load_model_from_path ( model )
221229 _validate_config (model , data , trainer , log = log , resume = resume , optim = optim , model_transform = peft )
222230 return train (
223231 model = model ,
@@ -226,7 +234,7 @@ def finetune(
226234 log = log ,
227235 resume = resume ,
228236 optim = optim ,
229- tokenizer = "model" ,
237+ tokenizer = tokenizer ,
230238 model_transform = peft ,
231239 )
232240
@@ -630,6 +638,7 @@ def deploy(
630638 the trtllm backend).
631639 """
632640 import os
641+
633642 import uvicorn
634643
635644 from nemo .deploy import DeployPyTriton
@@ -1345,3 +1354,9 @@ def _build_directory_tree(path, tree=None, root_name=None):
13451354 tree .add (f"[white]{ item .name } [/white]" )
13461355
13471356 return tree
1357+
1358+
1359+ def _load_model_from_path (model : Union [pl .LightningModule , AnyPath ]):
1360+ if isinstance (model , AnyPath ):
1361+ model = io .load_context (ckpt_to_context_subdir (model ), subpath = "model" )
1362+ return model
0 commit comments