Skip to content

Commit c370a89

Browse files
committed
Fix compatibility with legacy file URI format
- Handle both proper file URIs (file:///path) and legacy format (file:/path) - Proper URIs use urlparse/url2pathname for Windows compatibility - Legacy format used by constructor returns path as-is - Update tests to cover both formats and Windows behavior
1 parent 4cd31fe commit c370a89

File tree

2 files changed

+37
-22
lines changed

2 files changed

+37
-22
lines changed

src/lightning/pytorch/loggers/mlflow.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -300,11 +300,18 @@ def save_dir(self) -> Optional[str]:
300300
301301
"""
302302
if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
303-
from urllib.parse import urlparse
304-
from urllib.request import url2pathname
305-
306-
parsed_uri = urlparse(self._tracking_uri)
307-
return url2pathname(parsed_uri.path)
303+
# Handle both proper file URIs (file:///path) and legacy format (file:/path)
304+
uri_without_prefix = self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :]
305+
306+
# If it starts with ///, it's a proper file URI, use urlparse
307+
if uri_without_prefix.startswith("///"):
308+
from urllib.parse import urlparse
309+
from urllib.request import url2pathname
310+
311+
parsed_uri = urlparse(self._tracking_uri)
312+
return url2pathname(parsed_uri.path)
313+
# Legacy format: file:/path or file:./path - return as-is
314+
return uri_without_prefix
308315
return None
309316

310317
@property

tests/tests_pytorch/loggers/test_mlflow.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -432,30 +432,38 @@ def test_set_tracking_uri(mlflow_mock):
432432
@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
433433
def test_mlflow_logger_save_dir_file_uri_handling(mlflow_mock):
434434
"""Test that save_dir correctly handles file URIs, especially on Windows."""
435-
# Test Unix-style absolute file URI
436-
logger = MLFlowLogger(tracking_uri="file:///home/user/mlruns")
437-
expected_unix = "/home/user/mlruns"
438-
assert logger.save_dir == expected_unix
439-
440-
# Test Windows-style absolute file URI
441-
logger_win = MLFlowLogger(tracking_uri="file:///C:/Dev/example/mlruns")
442-
# On Windows, url2pathname converts file:///C:/path to C:\path
443-
# On Unix, it converts to /C:/path, but we test the actual behavior
444435
import platform
445436

437+
# Test proper Windows-style absolute file URI (the main fix)
438+
logger_win = MLFlowLogger(tracking_uri="file:///C:/Dev/example/mlruns")
439+
result_win = logger_win.save_dir
446440
expected_win = "C:\\Dev\\example\\mlruns" if platform.system() == "Windows" else "/C:/Dev/example/mlruns"
447-
assert logger_win.save_dir == expected_win
441+
assert result_win == expected_win
448442

449-
# Test relative file URI
443+
# Test proper Unix-style absolute file URI
444+
logger_unix = MLFlowLogger(tracking_uri="file:///home/user/mlruns")
445+
result_unix = logger_unix.save_dir
446+
expected_unix = "\\home\\user\\mlruns" if platform.system() == "Windows" else "/home/user/mlruns"
447+
assert result_unix == expected_unix
448+
449+
# Test proper file URI with special characters and spaces
450+
logger_special = MLFlowLogger(tracking_uri="file:///path/with%20spaces/mlruns")
451+
result_special = logger_special.save_dir
452+
expected_special = "\\path\\with spaces\\mlruns" if platform.system() == "Windows" else "/path/with spaces/mlruns"
453+
assert result_special == expected_special
454+
455+
# Test legacy format used by constructor (file:/path - should return as-is)
456+
logger_legacy = MLFlowLogger(tracking_uri="file:/tmp/mlruns")
457+
result_legacy = logger_legacy.save_dir
458+
expected_legacy = "/tmp/mlruns"
459+
assert result_legacy == expected_legacy
460+
461+
# Test legacy relative format
450462
logger_rel = MLFlowLogger(tracking_uri="file:./mlruns")
463+
result_rel = logger_rel.save_dir
451464
expected_rel = "./mlruns"
452-
assert logger_rel.save_dir == expected_rel
465+
assert result_rel == expected_rel
453466

454467
# Test non-file URI (should return None)
455468
logger_http = MLFlowLogger(tracking_uri="http://localhost:8080")
456469
assert logger_http.save_dir is None
457-
458-
# Test file URI with special characters and spaces
459-
logger_special = MLFlowLogger(tracking_uri="file:///path/with%20spaces/mlruns")
460-
expected_special = "/path/with spaces/mlruns"
461-
assert logger_special.save_dir == expected_special

0 commit comments

Comments
 (0)