Is passing model as an argument to LitModel a bad practice? #8648
Replies: 2 comments
-
Dear @alexyalunin,

This is actually a recommended practice, as you brilliantly discovered yourself :) A slightly better way is to provide the config used to generate the model, so it is serializable and easy to re-create from checkpoints: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/domain_templates/computer_vision_fine_tuning.py#L160

This is the pattern we adopted within Lightning Flash, our high-level library: https://github.com/PyTorchLightning/lightning-flash/blob/master/flash_examples/image_classification.py You should give it a try (a minimal sketch of the config pattern follows the Flash example below).

```python
import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
# 1. Create the DataModule
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "./data")
datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
)
# 2. Build the task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
# 4. Predict what's on a few images! ants or bees?
predictions = model.predict([
    "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
    "data/hymenoptera_data/val/bees/590318879_68cf112861.jpg",
    "data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg",
])
print(predictions)
# 5. Save the model!
trainer.save_checkpoint("image_classification_model.pt")
```

Best,
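For reference, here is a minimal sketch of the config-to-model pattern mentioned above, where only primitive config values (not the model object) are passed to the LightningModule. The `LitClassifier` name and the torchvision backbone lookup are illustrative, not taken from the linked example:

```python
import pytorch_lightning as pl
import torchvision.models as models

class LitClassifier(pl.LightningModule):
    def __init__(self, backbone: str = "resnet18", num_classes: int = 2):
        super().__init__()
        # Only primitive config values end up in hparams, so the
        # checkpoint stays serializable and the model is easy to rebuild.
        self.save_hyperparameters()
        self.model = getattr(models, backbone)(num_classes=num_classes)

# load_from_checkpoint re-runs __init__ with the stored backbone and
# num_classes, so no live model object ever needs to be pickled:
# model = LitClassifier.load_from_checkpoint("some.ckpt")
```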
-
Dear @tchaton, along those lines: I would like to check whether I can use

```python
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.core.mixins import HyperparametersMixin

class MyModel(nn.Module, HyperparametersMixin):
    def __init__(self, hidden_dim: int = 128):  # hypothetical arg; the original used "..."
        super().__init__()
        # Saves to self.hparams only, not to the logger (since there isn't any yet)
        self.save_hyperparameters()

class MyModule(pl.LightningModule):
    def __init__(self, model: nn.Module):
        super().__init__()
        # Otherwise everything is logged with a "model" prefix, which might be what you want
        self.save_hyperparameters(ignore="model")
        self.save_hyperparameters(model.hparams)
        self.model = model  # keep a reference so the submodule's parameters are registered

model = MyModel(hidden_dim=128)
module = MyModule(model)
```
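A possible way to use those saved hyperparameters at load time; this is a sketch, assuming `MyModel` can be rebuilt purely from its hparams, and the checkpoint path is illustrative:

```python
import torch

# Lightning checkpoints store save_hyperparameters() output under the
# "hyper_parameters" key; rebuild the inner model from it, then restore.
ckpt = torch.load("my_checkpoint.ckpt", map_location="cpu")
model = MyModel(**ckpt["hyper_parameters"])
module = MyModule.load_from_checkpoint("my_checkpoint.ckpt", model=model)
```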
-
I test different models on the same task, i.e. models with the same inputs and outputs, so the training/validation steps and metrics are all the same. I decided to pass the model as an argument, but now it seems I can't call save_hyperparameters(): I assume the whole model is treated as a hyperparameter and passed to the TensorBoard logger, which then hangs. Is passing a model as an argument to LitModel a bad practice?
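For reference, a minimal sketch of the setup described, applying the `ignore=` fix from the replies above; the body of `LitModel` (loss, `lr`, optimizer) is illustrative:

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    """Shared training/validation logic; the backbone model is injected."""

    def __init__(self, model: nn.Module, lr: float = 1e-3):
        super().__init__()
        # Skip the nn.Module so it is not serialized into hparams or
        # sent to the TensorBoard logger (the cause of the hang).
        self.save_hyperparameters(ignore="model")
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = nn.functional.cross_entropy(self.model(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
```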