diff --git a/README.md b/README.md index f597204..e0249fb 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,10 @@ from lightning.pytorch.callbacks import ModelCheckpoint from litmodels import upload_model from litmodels.demos import BoringModel +# Define the model name - this should be unique to your model +# The format is // +MY_MODEL_NAME = "jirka/kaggle/lit-boring-model" + class LitModel(BoringModel): def training_step(self, batch, batch_idx: int): @@ -48,32 +52,14 @@ class LitModel(BoringModel): return {"loss": loss} -# Define the model name - this should be unique to your model -# The format is // -MY_MODEL_NAME = "jirka/kaggle/lit-boring-model" - -# Define the model -model = LitModel() -# Save the best model based on validation loss -checkpoint_callback = ModelCheckpoint( - monitor="train_loss", # Metric to monitor - save_top_k=1, # Only save the best model (use -1 to save all) - mode="min", # 'min' for loss, 'max' for accuracy - save_last=True, # Additionally save the last checkpoint - dirpath="my_checkpoints/", # Directory to save checkpoints - filename="{epoch:02d}-{val_loss:.2f}", # Custom checkpoint filename -) - -# Train the model -trainer = Trainer( - max_epochs=2, - callbacks=[checkpoint_callback], -) -trainer.fit(model) +# Configure Lightning Trainer +trainer = Trainer(max_epochs=2) +# Define the model and train it +trainer.fit(LitModel()) # Upload the best model to cloud storage -print(f"last: {vars(checkpoint_callback)}") -upload_model(model=checkpoint_callback.last_model_path, name=MY_MODEL_NAME) +checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path") +upload_model(model=checkpoint_path, name=MY_MODEL_NAME) ``` To load the model, use the `download_model` function. @@ -84,6 +70,11 @@ from litmodels import download_model from litmodels.demos import BoringModel +# Define the model name - this should be unique to your model +# The format is //: +MY_MODEL_NAME = "jirka/kaggle/lit-boring-model:latest" + + class LitModel(BoringModel): def training_step(self, batch, batch_idx: int): loss = self.step(batch) @@ -92,20 +83,13 @@ class LitModel(BoringModel): return {"loss": loss} -# Define the model name - this should be unique to your model -# The format is //: -MY_MODEL_NAME = "jirka/kaggle/lit-boring-model:latest" - # Load the model from cloud storage -model_path = download_model(name=MY_MODEL_NAME, download_dir="my_models") -print(f"model: {model_path}") +checkpoint_path = download_model(name=MY_MODEL_NAME, download_dir="my_models") +print(f"model: {checkpoint_path}") # Train the model with extended training period trainer = Trainer(max_epochs=4) -trainer.fit( - LitModel(), - ckpt_path=model_path, -) +trainer.fit(LitModel(), ckpt_path=checkpoint_path) ``` You can also use model store together with [LitLogger](https://github.com/gridai/lit-logger) to log your model to the cloud storage. diff --git a/examples/train-simple.py b/examples/train-simple.py index ee688bd..060d8d3 100644 --- a/examples/train-simple.py +++ b/examples/train-simple.py @@ -1,7 +1,6 @@ import torch.utils.data as data import torchvision as tv from lightning import Trainer -from lightning.pytorch.callbacks import ModelCheckpoint from litmodels import upload_model from sample_model import LitAutoEncoder @@ -15,24 +14,13 @@ train, val = data.random_split(dataset, [55000, 5000]) autoencoder = LitAutoEncoder() - # Save the best model based on validation loss - checkpoint_callback = ModelCheckpoint( - monitor="val_loss", # Metric to monitor - save_top_k=1, # Only save the best model (use -1 to save all) - mode="min", # 'min' for loss, 'max' for accuracy - save_last=True, # Additionally save the last checkpoint - dirpath="my_checkpoints/", # Directory to save checkpoints - filename="{epoch:02d}-{val_loss:.2f}", # Custom checkpoint filename - ) - trainer = Trainer( - max_epochs=2, - callbacks=[checkpoint_callback], - ) + trainer = Trainer(max_epochs=2) trainer.fit( autoencoder, data.DataLoader(train, batch_size=256), data.DataLoader(val, batch_size=256), ) - print(f"last: {vars(checkpoint_callback)}") - upload_model(model=checkpoint_callback.last_model_path, name=MY_MODEL_NAME) + checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path") + print(f"best: {checkpoint_path}") + upload_model(model=checkpoint_path, name=MY_MODEL_NAME)