diff --git a/_requirements/test.txt b/_requirements/test.txt index 26ce2e8..4b56175 100644 --- a/_requirements/test.txt +++ b/_requirements/test.txt @@ -3,3 +3,5 @@ pytest >=6.0 pytest-cov pytest-mock mypy ==1.13.0 + +pytorch-lightning >=2.0 diff --git a/src/litmodels/cloud_io.py b/src/litmodels/cloud_io.py index 8fa8139..2dad671 100644 --- a/src/litmodels/cloud_io.py +++ b/src/litmodels/cloud_io.py @@ -73,7 +73,7 @@ def upload_model( def download_model( name: str, - download_dir: Optional[str] = None, + download_dir: str = ".", progress_bar: bool = True, ) -> str: """Download a checkpoint from the model store.