1313# limitations under the License.
1414import os
1515import re
16+ from typing import Optional
1617
1718from lightning_utilities import module_available
1819
1920import lightning .pytorch as pl
2021from lightning .fabric .utilities .imports import _IS_WINDOWS
22+ from lightning .fabric .utilities .types import _PATH
2123
2224# skip these test on Windows as the path notation differ
2325if _IS_WINDOWS :
2426 __doctest_skip__ = ["_determine_model_folder" ]
2527
2628
27- def _is_registry (text : str ) -> bool :
29+ def _is_registry (text : Optional [ _PATH ] ) -> bool :
2830 """Check if a string equals 'registry' or starts with 'registry:'.
2931
3032 Args:
@@ -48,7 +50,7 @@ def _is_registry(text: str) -> bool:
4850 return bool (re .match (pattern , text .lower ()))
4951
5052
51- def _parse_registry_model_version (ckpt_path : str ) -> tuple [str , str ]:
53+ def _parse_registry_model_version (ckpt_path : Optional [ _PATH ] ) -> tuple [str , str ]:
5254 """Parse the model version from a registry path.
5355
5456 Args:
@@ -84,7 +86,7 @@ def _parse_registry_model_version(ckpt_path: str) -> tuple[str, str]:
8486 return model_name , version
8587
8688
87- def _determine_model_name (ckpt_path : str , default_model_registry : str ) -> str :
89+ def _determine_model_name (ckpt_path : Optional [ _PATH ] , default_model_registry : Optional [ str ] ) -> str :
8890 """Determine the model name from the checkpoint path.
8991
9092 Args:
@@ -105,8 +107,10 @@ def _determine_model_name(ckpt_path: str, default_model_registry: str) -> str:
105107 # try to find model and version
106108 model_name , model_version = _parse_registry_model_version (ckpt_path )
107109 # omitted model name try to use the model registry from Trainer
108- if not model_name :
110+ if not model_name and default_model_registry :
109111 model_name = default_model_registry
112+ if not model_name :
113+ raise ValueError (f"Invalid model registry: '{ ckpt_path } '" )
110114 model_registry = model_name
111115 model_registry += f":{ model_version } " if model_version else ""
112116 return model_registry
@@ -137,7 +141,9 @@ def _determine_model_folder(model_name: str, default_root_dir: str) -> str:
137141 return local_model_dir
138142
139143
140- def find_model_local_ckpt_path (ckpt_path : str , default_model_registry : str , default_root_dir : str ) -> str :
144+ def find_model_local_ckpt_path (
145+ ckpt_path : Optional [_PATH ], default_model_registry : Optional [str ], default_root_dir : str
146+ ) -> str :
141147 """Find the local checkpoint path for a model."""
142148 model_registry = _determine_model_name (ckpt_path , default_model_registry )
143149 local_model_dir = _determine_model_folder (model_registry , default_root_dir )
0 commit comments