Skip to content

Commit b223878

Browse files
committed
typing
1 parent 81bd9ba commit b223878

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

src/lightning/pytorch/utilities/model_registry.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,20 @@
1313
# limitations under the License.
1414
import os
1515
import re
16+
from typing import Optional
1617

1718
from lightning_utilities import module_available
1819

1920
import lightning.pytorch as pl
2021
from lightning.fabric.utilities.imports import _IS_WINDOWS
22+
from lightning.fabric.utilities.types import _PATH
2123

2224
# skip these test on Windows as the path notation differ
2325
if _IS_WINDOWS:
2426
__doctest_skip__ = ["_determine_model_folder"]
2527

2628

27-
def _is_registry(text: str) -> bool:
29+
def _is_registry(text: Optional[_PATH]) -> bool:
2830
"""Check if a string equals 'registry' or starts with 'registry:'.
2931
3032
Args:
@@ -48,7 +50,7 @@ def _is_registry(text: str) -> bool:
4850
return bool(re.match(pattern, text.lower()))
4951

5052

51-
def _parse_registry_model_version(ckpt_path: str) -> tuple[str, str]:
53+
def _parse_registry_model_version(ckpt_path: Optional[_PATH]) -> tuple[str, str]:
5254
"""Parse the model version from a registry path.
5355
5456
Args:
@@ -84,7 +86,7 @@ def _parse_registry_model_version(ckpt_path: str) -> tuple[str, str]:
8486
return model_name, version
8587

8688

87-
def _determine_model_name(ckpt_path: str, default_model_registry: str) -> str:
89+
def _determine_model_name(ckpt_path: Optional[_PATH], default_model_registry: Optional[str]) -> str:
8890
"""Determine the model name from the checkpoint path.
8991
9092
Args:
@@ -105,8 +107,10 @@ def _determine_model_name(ckpt_path: str, default_model_registry: str) -> str:
105107
# try to find model and version
106108
model_name, model_version = _parse_registry_model_version(ckpt_path)
107109
# omitted model name try to use the model registry from Trainer
108-
if not model_name:
110+
if not model_name and default_model_registry:
109111
model_name = default_model_registry
112+
if not model_name:
113+
raise ValueError(f"Invalid model registry: '{ckpt_path}'")
110114
model_registry = model_name
111115
model_registry += f":{model_version}" if model_version else ""
112116
return model_registry
@@ -137,7 +141,9 @@ def _determine_model_folder(model_name: str, default_root_dir: str) -> str:
137141
return local_model_dir
138142

139143

140-
def find_model_local_ckpt_path(ckpt_path: str, default_model_registry: str, default_root_dir: str) -> str:
144+
def find_model_local_ckpt_path(
145+
ckpt_path: Optional[_PATH], default_model_registry: Optional[str], default_root_dir: str
146+
) -> str:
141147
"""Find the local checkpoint path for a model."""
142148
model_registry = _determine_model_name(ckpt_path, default_model_registry)
143149
local_model_dir = _determine_model_folder(model_registry, default_root_dir)

0 commit comments

Comments
 (0)