Exporting PyTorch Lightning model to ONNX format not working #10063
-
I am running this in JupyterLab with the pre-installed tf2.3_py3.6 kernel, on a machine with 2 GPUs.
Here is a screenshot of my model training, which was interrupted by a connection issue. I am saving the best model in a checkpoint.
Here is the DataModule Class
Here is the model class:
Sample Data
Model
ONNX code
Error
Replies: 3 comments
-
Hi,

A Google search turns up some help on this issue here: pytorch/pytorch#31591

Citing the thread there:

If we look closer at your code, we see that loss=0 and labels=None.

```python
def forward(self, input_ids, attention_mask, labels=None):
    output = self.bert(input_ids, attention_mask=attention_mask)
    output = self.classifier(output.pooler_output)
    output = torch.sigmoid(output)
    loss = 0
    if labels is not None:
        loss = self.criterion(output, labels)
    return loss, output
```

The if condition does not hold, so part of your output (the loss) cannot be traced back to any inputs by ONNX. Please change your code to something like this and try again:

```python
    if labels is not None:
        loss = self.criterion(output, labels)
        return loss, output
    return output
```
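As a rough, self-contained illustration of the suggested change, the sketch below applies the same early-return pattern to a tiny stand-in model. `TinyTagger`, its embedding layer, and the masked mean pooling are hypothetical substitutes for the BERT encoder and `pooler_output` in the original post; only the `forward` branching mirrors the thread.

```python
import torch
from torch import nn


class TinyTagger(nn.Module):
    """Hypothetical stand-in for the BERT classifier discussed above."""

    def __init__(self, vocab_size=100, hidden=16, n_classes=3):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)  # stands in for self.bert
        self.classifier = nn.Linear(hidden, n_classes)
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask, labels=None):
        mask = attention_mask.unsqueeze(-1).float()
        # Masked mean pooling replaces BERT's pooler_output in this sketch.
        summed = (self.embed(input_ids) * mask).sum(dim=1)
        pooled = summed / mask.sum(dim=1).clamp(min=1.0)
        output = torch.sigmoid(self.classifier(pooled))
        if labels is not None:
            # Training path: return the loss together with the predictions.
            loss = self.criterion(output, labels)
            return loss, output
        # Inference/export path: every returned value is a tensor
        # computed from the inputs.
        return output


model = TinyTagger().eval()
input_ids = torch.randint(0, 100, (2, 8))
attention_mask = torch.ones(2, 8, dtype=torch.long)
labels = torch.randint(0, 2, (2, 3)).float()

with torch.no_grad():
    probs = model(input_ids, attention_mask)
    loss, probs_with_loss = model(input_ids, attention_mask, labels)
```

With `labels` omitted, `forward` returns a single tensor, which is the form the exporter sees when it is given only `(input_ids, attention_mask)` as example inputs.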
-
@awaelchli OK, this solves my problem. However, I need one clarification. In the code to save in ONNX format:
Is this the correct way to pass
Secondly, does this save my best model, which was trained above using PyTorch Lightning?
-
Great! Glad the main issue is resolved. I'm converting this issue to a discussion thread. Also check out the PL docs for the to_onnx method on the LightningModule.
Yes, looks good.
When you run trainer.fit(model) it will train, but when it finishes the model may not hold the best weights. You can load the model with the best weights by accessing the path
Check the ModelCheckpoint callback docs for how to use it :)
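One possible way to put this together is sketched below. The helper `export_best_checkpoint` and its parameter names are hypothetical; it assumes the trainer was configured with a ModelCheckpoint callback, so that `trainer.checkpoint_callback.best_model_path` is populated, and that the model class is a LightningModule (which provides `load_from_checkpoint` and `to_onnx`).

```python
def export_best_checkpoint(trainer, model_cls, input_sample, onnx_path="model.onnx"):
    """Reload the best checkpoint and export it to ONNX (hypothetical helper)."""
    # ModelCheckpoint records the path of the best checkpoint it saved;
    # the attribute is an empty string if nothing has been saved yet.
    best_path = trainer.checkpoint_callback.best_model_path
    if not best_path:
        raise ValueError("No best checkpoint recorded; was ModelCheckpoint used?")

    # Restore the weights from the best checkpoint instead of reusing the
    # in-memory model, which holds the weights from the last training step.
    model = model_cls.load_from_checkpoint(best_path)

    # input_sample is handed to the exporter as example input for forward().
    model.to_onnx(onnx_path, input_sample, export_params=True)
    return best_path
```

For a `forward(input_ids, attention_mask)` signature like the one in this thread, `input_sample` would presumably be a tuple of two example tensors, for instance `export_best_checkpoint(trainer, Model, (input_ids, attention_mask))`, where `Model` stands for your LightningModule class.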