-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Fix RichProgressBar loading when using LightningCLI
#21340
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
for more information, see https://pre-commit.ci
|
Previously, the progress bar was not properly initialized when running through the This change ensures that the Minimal working example: import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.callbacks import RichProgressBar
# ==== Minimal Example Model and DataModule ====
INPUT_DIM = 10
OUTPUT_DIM = 1
class Model(pl.LightningModule):
def __init__(self):
super().__init__()
self.layer = nn.Linear(INPUT_DIM, OUTPUT_DIM)
def forward(self, x):
return self.layer(x)
def training_step(self, batch, _):
x, y = batch
loss = nn.functional.mse_loss(self(x), y)
return loss
def configure_optimizers(self):
return torch.optim.SGD(self.parameters(), lr=0.1)
class Data(pl.LightningDataModule):
def setup(self, stage=None):
n_samples = 100
x = torch.randn(n_samples, INPUT_DIM)
y = torch.randn(n_samples, OUTPUT_DIM)
self.ds = TensorDataset(x, y)
def train_dataloader(self):
return DataLoader(self.ds, batch_size=32)
# ==== CLI entrypoint ====
class CLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_lightning_class_args(RichProgressBar, "rich_progress_bar")
def main():
CLI(Model, Data)
if __name__ == "__main__":
main() |
| def __init__(self) -> None: | ||
| pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it is unclear to me why this change is required. Data-classes generally don't require an init function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The redundant __init__() here is just a workaround to ensure that jsonargparse or typeshed_client can correctly parse RichProgressBarTheme when it is used as the default argument to RichProgressBar in the CLI, as shown in the script above.
Without this workaround, the script throws a key-validation error:
$ python3 demo.py fit
usage: test.py [-h] [-c CONFIG] [--print_config[=flags]] {fit,validate,test,predict} ...
error: Validation failed: Parser key "rich_progress_bar.theme.progress_bar":
Does not validate against any of the Union subtypes
Subtypes: [<class 'str'>, <class 'rich.style.Style'>]
Errors:
- Expected a <class 'str'>. Got value: Not a valid subclass of Style. Got value: None
Subclass types expect one of:
- a class path (str)
- a dict with class_path entry
- a dict without class_path but with init_args entry (class path given previously)
- Not a valid subclass of Style. Got value: Not a valid subclass of Style. Got value: None
Subclass types expect one of:
- a class path (str)
- a dict with class_path entry
- a dict without class_path but with init_args entry (class path given previously)
Subclass types expect one of:
- a class path (str)
- a dict with class_path entry
- a dict without class_path but with init_args entry (class path given previously)
Given value type: <class 'ValueError'>
Given value: Not a valid subclass of Style. Got value: None
Subclass types expect one of:
- a class path (str)
- a dict with class_path entry
- a dict without class_path but with init_args entry (class path given previously)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah but then it should be a subclass of rich.style.Style actually and not a dataclass anymore
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not familiar with the internal parsing mechanism, but as far as I know, either adding the redundant __init__() to this dataclass, or replacing all default values of arguments of RichProgressBarTheme typed Union[str, "Style"] with "" can solve this problem, where the latter one is clearly not ideal.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The following fix correctly applies the options to the progress bar theme, but it still triggers a JsonargparseWarning (although it is no longer an error):
@dataclass
class RichProgressBarTheme:
description: Union[str, "Style"] = ""
progress_bar: Union[str, "Style"] = Style(color="#6206E0")
progress_bar_finished: Union[str, "Style"] = Style(color="#6206E0")
progress_bar_pulse: Union[str, "Style"] = Style(color="#6206E0")
batch_progress: Union[str, "Style"] = ""
time: Union[str, "Style"] = Style(dim=True)
processing_speed: Union[str, "Style"] = Style(dim=True, underline=True)
metrics: Union[str, "Style"] = Style(italic=True)
metrics_text_delimiter: str = " "
metrics_format: str = ".3f"There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mauvilsa could you maybe advise what's the best way to go here? Adding an init function to a data-class seems not very pythonic to me.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If a recent version of jsonargparse for some reason broke the support for RichProgressBar, then the fix would belong in jsonargparse, not here. A "workaround" is generally not a good solution.
I still need to analyze what actually is happening. Once I have, I will comment here again.
📚 Documentation preview 📚: https://pytorch-lightning--21340.org.readthedocs.build/en/21340/