Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 18 additions & 34 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ from lightning.pytorch.callbacks import ModelCheckpoint
from litmodels import upload_model
from litmodels.demos import BoringModel

# Define the model name - this should be unique to your model
# The format is <organization>/<teamspace>/<model-name>
MY_MODEL_NAME = "jirka/kaggle/lit-boring-model"


class LitModel(BoringModel):
def training_step(self, batch, batch_idx: int):
Expand All @@ -48,32 +52,14 @@ class LitModel(BoringModel):
return {"loss": loss}


# Define the model name - this should be unique to your model
# The format is <organization>/<teamspace>/<model-name>
MY_MODEL_NAME = "jirka/kaggle/lit-boring-model"

# Define the model
model = LitModel()
# Save the best model based on validation loss
checkpoint_callback = ModelCheckpoint(
monitor="train_loss", # Metric to monitor
save_top_k=1, # Only save the best model (use -1 to save all)
mode="min", # 'min' for loss, 'max' for accuracy
save_last=True, # Additionally save the last checkpoint
dirpath="my_checkpoints/", # Directory to save checkpoints
filename="{epoch:02d}-{val_loss:.2f}", # Custom checkpoint filename
)

# Train the model
trainer = Trainer(
max_epochs=2,
callbacks=[checkpoint_callback],
)
trainer.fit(model)
# Configure Lightning Trainer
trainer = Trainer(max_epochs=2)
# Define the model and train it
trainer.fit(LitModel())

# Upload the best model to cloud storage
print(f"last: {vars(checkpoint_callback)}")
upload_model(model=checkpoint_callback.last_model_path, name=MY_MODEL_NAME)
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
```

To load the model, use the `download_model` function.
Expand All @@ -84,6 +70,11 @@ from litmodels import download_model
from litmodels.demos import BoringModel


# Define the model name - this should be unique to your model
# The format is <organization>/<teamspace>/<model-name>:<model-version>
MY_MODEL_NAME = "jirka/kaggle/lit-boring-model:latest"


class LitModel(BoringModel):
def training_step(self, batch, batch_idx: int):
loss = self.step(batch)
Expand All @@ -92,20 +83,13 @@ class LitModel(BoringModel):
return {"loss": loss}


# Define the model name - this should be unique to your model
# The format is <organization>/<teamspace>/<model-name>:<model-version>
MY_MODEL_NAME = "jirka/kaggle/lit-boring-model:latest"

# Load the model from cloud storage
model_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
print(f"model: {model_path}")
checkpoint_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
print(f"model: {checkpoint_path}")

# Train the model with extended training period
trainer = Trainer(max_epochs=4)
trainer.fit(
LitModel(),
ckpt_path=model_path,
)
trainer.fit(LitModel(), ckpt_path=checkpoint_path)
```

You can also use model store together with [LitLogger](https://github.com/gridai/lit-logger) to log your model to the cloud storage.
Expand Down
20 changes: 4 additions & 16 deletions examples/train-simple.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch.utils.data as data
import torchvision as tv
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from litmodels import upload_model
from sample_model import LitAutoEncoder

Expand All @@ -15,24 +14,13 @@
train, val = data.random_split(dataset, [55000, 5000])

autoencoder = LitAutoEncoder()
# Save the best model based on validation loss
checkpoint_callback = ModelCheckpoint(
monitor="val_loss", # Metric to monitor
save_top_k=1, # Only save the best model (use -1 to save all)
mode="min", # 'min' for loss, 'max' for accuracy
save_last=True, # Additionally save the last checkpoint
dirpath="my_checkpoints/", # Directory to save checkpoints
filename="{epoch:02d}-{val_loss:.2f}", # Custom checkpoint filename
)

trainer = Trainer(
max_epochs=2,
callbacks=[checkpoint_callback],
)
trainer = Trainer(max_epochs=2)
trainer.fit(
autoencoder,
data.DataLoader(train, batch_size=256),
data.DataLoader(val, batch_size=256),
)
print(f"last: {vars(checkpoint_callback)}")
upload_model(model=checkpoint_callback.last_model_path, name=MY_MODEL_NAME)
checkpoint_path = getattr(trainer.checkpoint_callback, "best_model_path")
print(f"best: {checkpoint_path}")
upload_model(model=checkpoint_path, name=MY_MODEL_NAME)
Loading