Question: How to switch models? #104
Replies: 13 comments 2 replies
-
@yukkyo Thanks! I'm not quite sure what you want to do. The normal flow is:
-
@williamFalcon First of all, I don't want to write the same code for every model (training_step(), validation_step(), tng_dataloader(), and so on). As you say, if I want to compare multiple models (and training_step() etc. are the same),
Is that correct?
-
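Oh, you have different systems?
File 1:
File 2:
File 3:
```python
from pytorch_lightning import Trainer

# Model1 and Model2 are the LightningModules defined in File 1 and File 2
if want_model_1:
    model = Model1()
else:
    model = Model2()

trainer = Trainer(...)
trainer.fit(model)
```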
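A minimal sketch of the File 3 idea using only the standard library, assuming the choice of system comes from a command-line flag. `Model1` and `Model2` here are hypothetical placeholder classes, and the `--model` flag, the `SYSTEMS` dict and `build_model` are illustrative names, not part of Lightning; in a real project each class would be a full `LightningModule` imported from its own file and then passed to `Trainer(...).fit(model)`.

```python
import argparse


# Hypothetical stand-ins for two complete LightningModule systems
# (in the thread these would live in File 1 and File 2).
class Model1:
    pass


class Model2:
    pass


# Map command-line names to system classes.
SYSTEMS = {"model1": Model1, "model2": Model2}


def build_model(argv=None):
    """Instantiate the system selected by --model (argv=None reads sys.argv)."""
    parser = argparse.ArgumentParser()
    # choices= makes argparse reject unknown names instead of silently
    # falling through to a default branch
    parser.add_argument("--model", choices=sorted(SYSTEMS), default="model1")
    args = parser.parse_args(argv)
    return SYSTEMS[args.model]()
```

Calling `build_model(["--model", "model2"])` returns a `Model2` instance, while an unknown name such as `model3` makes argparse print a usage error and exit, which is usually preferable to training the wrong system.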
-
@williamFalcon
-
I'm sorry, I'm not sure I understand. If the training_step is the same, then define one module. What are you trying to do? Can you write pseudocode?
-
Sorry, I'm not good at explaining this. In my case, I want to compare the metrics of Network1 and Network2.

Method 1: Define a LightningModule for each network

```python
from pytorch_lightning import LightningModule


class LightningTemplateModel1(LightningModule):
    """
    Sample model to show how to define a template
    """
    def __init__(self, hparams):
        super().__init__()
        # build model
        self.__build_model()

    def __build_model(self):
        # Define Network1
        # Long code
        ...

    def forward(self, x):
        # Long code that computes logits
        return logits

    def training_step(self, data_batch, batch_i):
        # Long code
        ...

    def validation_step(self, data_batch, batch_i, dataloader_i):
        # Long code
        ...


class LightningTemplateModel2(LightningModule):
    """
    Sample model to show how to define a template
    """
    def __init__(self, hparams):
        super().__init__()
        # build model
        self.__build_model()

    def __build_model(self):
        # Define Network2
        # Long code
        ...

    def forward(self, x):
        # Long code that computes logits
        return logits

    def training_step(self, data_batch, batch_i):
        # Long code (same as in LightningTemplateModel1)
        ...

    def validation_step(self, data_batch, batch_i, dataloader_i):
        # Long code (same as in LightningTemplateModel1)
        ...
```

Method 2: Switch networks by hparams

```python
from pytorch_lightning import LightningModule
from test_tube import HyperOptArgumentParser


class LightningTemplateModel(LightningModule):
    def __init__(self, hparams):
        super().__init__()  # required before assigning submodules
        # Switch model by hparams; build_model1/build_model2 return nn.Modules
        if hparams.model_type == 'Network1':
            self.__my_model = build_model1()
        else:
            self.__my_model = build_model2()

    def forward(self, x):
        return self.__my_model(x)

    def training_step(self, data_batch, batch_i):
        # Long code
        return

    def validation_step(self, data_batch, batch_i, dataloader_i):
        # Long code
        return

    @staticmethod
    def add_model_specific_args(parent_parser, root_dir):  # pragma: no cover
        parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
        parser.add_argument('--model_type', default='Network1', type=str)
        return parser
```

Method 3: Use a subclass

```python
from pytorch_lightning import LightningModule


class MyLightningTemplateModel(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        # Single underscore: a double-underscore name is mangled per class,
        # so the base class could never call a subclass's override
        self._build_model()

    def _build_model(self):
        raise NotImplementedError()

    def forward(self, x):
        raise NotImplementedError()

    def training_step(self, data_batch, batch_i):
        # Long code
        return

    def validation_step(self, data_batch, batch_i, dataloader_i):
        # Long code
        return


class LightningTemplateModel1(MyLightningTemplateModel):
    # __init__ is inherited; it calls the _build_model() below
    def _build_model(self):
        # Define Network1
        # Long code
        return

    def forward(self, x):
        # Long code
        ...


class LightningTemplateModel2(MyLightningTemplateModel):
    def _build_model(self):
        # Define Network2
        # Long code
        return

    def forward(self, x):
        # Long code
        ...
```
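Method 3 relies on the base class calling a build method that each subclass overrides. With a double-underscore name like `__build_model`, that dispatch does not happen: Python mangles `self.__name` to `self._ClassName__name` using the class in which the call is written. So once subclasses call `super().__init__()` (which a `LightningModule`, being an `nn.Module`, needs), the base `__init__` would always reach its own `NotImplementedError` version. The plain-Python sketch below (hypothetical `Base`/`Child` classes, no Lightning involved) shows the behavior and the single-underscore fix.

```python
class Base:
    def __init__(self):
        # Written inside Base, so this is mangled to self._Base__build():
        # it always resolves to Base's method, never a subclass override.
        self.kind = self.__build()

    def __build(self):
        return "base"


class Child(Base):
    def __build(self):  # stored as _Child__build; Base.__init__ never calls it
        return "child"


class BaseFixed:
    def __init__(self):
        # A single leading underscore is not mangled, so ordinary
        # method overriding applies.
        self.kind = self._build()

    def _build(self):
        raise NotImplementedError()


class ChildFixed(BaseFixed):
    def _build(self):
        return "child"


print(Child().kind)       # base
print(ChildFixed().kind)  # child
```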
-
Yeah, method 1 is the intended pattern.
-
Thank you! I'll do my best to write it!
-
Wait, actually: if model 1 is just a network and model 2 is a different architecture, then use method 2.
-
I assume that Network1 is YOLO and Network2 is RetinaNet.
-
Yeah, just use hparams to toggle. The LightningModule is meant for a full system; you are just specifying a particular architecture within a broader system.
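A plain-Python sketch of the hparams toggle, assuming one system whose training logic is shared and whose network is picked from `hparams.model_type`. The builders `build_network1`/`build_network2`, the `NETWORKS` dict and the `System` class are hypothetical stand-ins (the "networks" are simple lambdas rather than `nn.Module`s) so the example runs without PyTorch. Unlike an if/else whose `else` branch catches everything, the dictionary lookup raises on a misspelled model type.

```python
from argparse import Namespace


# Hypothetical stand-ins for real network builders; in a LightningModule
# each would return an nn.Module (e.g. a YOLO- or RetinaNet-style model).
def build_network1():
    return lambda xs: [v * 2 for v in xs]


def build_network2():
    return lambda xs: [v + 1 for v in xs]


NETWORKS = {"Network1": build_network1, "Network2": build_network2}


class System:
    """One system: the training logic is written once, and only the
    architecture is chosen by hparams.model_type."""

    def __init__(self, hparams):
        try:
            builder = NETWORKS[hparams.model_type]
        except KeyError:
            # Fail loudly rather than silently falling back to another network
            raise ValueError(f"unknown model_type: {hparams.model_type!r}") from None
        self.network = builder()

    def forward(self, xs):
        return self.network(xs)

    def training_step(self, batch):
        # Shared by every architecture: squared-error loss on the outputs
        xs, targets = batch
        preds = self.forward(xs)
        return sum((p - t) ** 2 for p, t in zip(preds, targets))


batch = ([1, 2], [2, 4])
print(System(Namespace(model_type="Network1")).training_step(batch))  # 0
print(System(Namespace(model_type="Network2")).training_step(batch))  # 1
```

Comparing the two architectures then amounts to running the same script twice with a different `model_type`, with everything else held fixed.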
-
Thank you! That was a really big help!
-
@williamFalcon
-
Thank you for the wonderful library! 😄
I looked at the example below and couldn't figure out how to compare models trained on the same dataset.
https://github.com/williamFalcon/pytorch-lightning/blob/master/examples/new_project_templates/lightning_module_template.py
There are several ways to do this.
Which is a good practice?
Or do you know a better way?