Skip to content

Commit bb467bd

Browse files
committed
examples
1 parent 6959d19 commit bb467bd

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

examples/train-callback.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@
44
from litmodels import upload_model
55
from sample_model import LitAutoEncoder
66

7+
# Define the model name - this should be unique to your model
8+
# The format is <organization>/<teamspace>/<model-name>
9+
MY_MODEL_NAME = "jirka/kaggle/lit-auto-encoder-callback"
10+
711

812
class UploadModelCallback(Callback):
913
def on_train_epoch_end(self, trainer, pl_module):
1014
# Get the best model path from the checkpoint callback
1115
best_model_path = trainer.checkpoint_callback.best_model_path
1216
if best_model_path:
1317
print(f"Uploading model: {best_model_path}")
14-
upload_model(path=best_model_path, name="jirka/kaggle/lit-auto-encoder-callback")
18+
upload_model(path=best_model_path, name=MY_MODEL_NAME)
1519

1620

1721
if __name__ == "__main__":

examples/train-resume.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@
44
from litmodels import download_model
55
from sample_model import LitAutoEncoder
66

7+
# Define the model name - this should be unique to your model
8+
# The format is <organization>/<teamspace>/<model-name>:<model-version>
9+
MY_MODEL_NAME = "jirka/kaggle/lit-auto-encoder-callback:latest"
10+
11+
712
if __name__ == "__main__":
813
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
914
train, val = data.random_split(dataset, [55000, 5000])
1015

11-
model_path = download_model(name="jirka/kaggle/lit-auto-encoder-simple:latest", download_dir="my_models")
16+
model_path = download_model(name=MY_MODEL_NAME, download_dir="my_models")
1217
print(f"model: {model_path}")
1318
# autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint_path=model_path)
1419

examples/train-simple.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
from litmodels import upload_model
66
from sample_model import LitAutoEncoder
77

8+
# Define the model name - this should be unique to your model
9+
# The format is <organization>/<teamspace>/<model-name>
10+
MY_MODEL_NAME = "jirka/kaggle/lit-auto-encoder-callback"
11+
12+
813
if __name__ == "__main__":
914
dataset = tv.datasets.MNIST(".", download=True, transform=tv.transforms.ToTensor())
1015
train, val = data.random_split(dataset, [55000, 5000])
@@ -30,4 +35,4 @@
3035
data.DataLoader(val, batch_size=256),
3136
)
3237
print(f"last: {vars(checkpoint_callback)}")
33-
upload_model(path=checkpoint_callback.last_model_path, name="jirka/kaggle/lit-auto-encoder-simple")
38+
upload_model(path=checkpoint_callback.last_model_path, name=MY_MODEL_NAME)

0 commit comments

Comments
 (0)