Skip to content

Commit aca3e2e

Browse files
committed
fix case-sensitive model name (#20661)
(cherry picked from commit 851d022)
1 parent c8f30de commit aca3e2e

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/lightning/pytorch/utilities/model_registry.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,24 +63,24 @@ def _parse_registry_model_version(ckpt_path: Optional[_PATH]) -> tuple[str, str]
6363
('model-name', '1.0')
6464
>>> _parse_registry_model_version("registry:model-name")
6565
('model-name', '')
66-
>>> _parse_registry_model_version("registry:version:v2")
66+
>>> _parse_registry_model_version("registry:VERSION:v2")
6767
('', 'v2')
6868
6969
"""
7070
if not ckpt_path or not _is_registry(ckpt_path):
7171
raise ValueError(f"Invalid registry path: {ckpt_path}")
7272

7373
# Split the path by ':'
74-
parts = str(ckpt_path).lower().split(":")
74+
parts = str(ckpt_path).split(":")
7575
# Default values
7676
model_name, version = "", ""
7777

7878
# Extract the model name and version based on the parts
79-
if len(parts) >= 2 and parts[1] != "version":
79+
if len(parts) >= 2 and parts[1].lower() != "version":
8080
model_name = parts[1]
81-
if len(parts) == 3 and parts[1] == "version":
81+
if len(parts) == 3 and parts[1].lower() == "version":
8282
version = parts[2]
83-
elif len(parts) == 4 and parts[2] == "version":
83+
elif len(parts) == 4 and parts[2].lower() == "version":
8484
version = parts[3]
8585

8686
return model_name, version

0 commit comments

Comments
 (0)