Skip to content

Commit 71ef218

Browse files
deependujhaCopilot
andauthored
feat: Add support for path in map fn (Lightning-AI#582)
* fix: enhance directory resolution and add time template test cases * update * update * Update tests/processing/test_functions.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/processing/test_functions.py * Update tests/processing/test_functions.py * Update tests/streaming/test_resolver.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update tests/processing/test_functions.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * update * update --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 96238b6 commit 71ef218

File tree

4 files changed

+83
-15
lines changed

4 files changed

+83
-15
lines changed

src/litdata/processing/functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,8 @@ def prepare_item(self, item_metadata: Any) -> Any:
207207
def map(
208208
fn: Callable[[str, Any], None],
209209
inputs: Union[Sequence[Any], StreamingDataLoader],
210-
output_dir: Union[str, Dir],
211-
input_dir: Optional[str] = None,
210+
output_dir: Union[str, Path, Dir],
211+
input_dir: Optional[Union[str, Path]] = None,
212212
weights: Optional[List[int]] = None,
213213
num_workers: Optional[int] = None,
214214
fast_dev_run: Union[bool, int] = False,

src/litdata/streaming/resolver.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,28 +38,28 @@ class Dir:
3838
url: Optional[str] = None
3939

4040

41-
def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir:
41+
def _resolve_dir(dir_path: Optional[Union[str, Path, Dir]]) -> Dir:
4242
if isinstance(dir_path, Dir):
4343
return Dir(path=str(dir_path.path) if dir_path.path else None, url=str(dir_path.url) if dir_path.url else None)
4444

4545
if dir_path is None:
4646
return Dir()
4747

48-
if not isinstance(dir_path, str):
49-
raise ValueError(f"`dir_path` must be a `Dir` or a string, got: {dir_path}")
48+
if not isinstance(dir_path, (str, Path)):
49+
raise ValueError(f"`dir_path` must be either a string, Path, or Dir, got: {dir_path}")
5050

51-
assert isinstance(dir_path, str)
51+
if isinstance(dir_path, str):
52+
cloud_prefixes = ("s3://", "gs://", "azure://", "hf://")
53+
if dir_path.startswith(cloud_prefixes):
54+
return Dir(path=None, url=dir_path)
5255

53-
cloud_prefixes = ("s3://", "gs://", "azure://", "hf://")
54-
if dir_path.startswith(cloud_prefixes):
55-
return Dir(path=None, url=dir_path)
56+
if dir_path.startswith("local:"):
57+
return Dir(path=None, url=dir_path)
5658

57-
if dir_path.startswith("local:"):
58-
return Dir(path=None, url=dir_path)
59-
60-
dir_path = _resolve_time_template(dir_path)
59+
dir_path = _resolve_time_template(dir_path)
6160

6261
dir_path_absolute = str(Path(dir_path).absolute().resolve())
62+
dir_path = str(dir_path) # Convert to string if it was a Path object
6363

6464
if dir_path_absolute.startswith("/teamspace/studios/this_studio"):
6565
return Dir(path=dir_path_absolute, url=None)
@@ -345,7 +345,22 @@ def _get_lightning_cloud_url() -> str:
345345

346346

347347
def _resolve_time_template(path: str) -> str:
348-
match = re.search("^.*{%.*}$", path)
348+
"""Resolves a datetime pattern in the given path string.
349+
350+
If the path contains a placeholder in the form `{%Y-%m-%d}`, it replaces it
351+
with the current date/time formatted using the specified `strftime` pattern.
352+
353+
Example:
354+
Input: "/logs/log_{%Y-%m-%d}.txt"
355+
Output (on May 5, 2025): "/logs/log_2025-05-05.txt"
356+
357+
Args:
358+
path (str): The file path containing an optional datetime placeholder.
359+
360+
Returns:
361+
str: The path with the datetime placeholder replaced by the current timestamp.
362+
"""
363+
match = re.search("^.*{%.*}.*$", path)
349364
if match is None:
350365
return path
351366

tests/processing/test_functions.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515
from PIL import Image
1616

17-
from litdata import StreamingDataset, merge_datasets, optimize, walk
17+
from litdata import StreamingDataset, map, merge_datasets, optimize, walk
1818
from litdata.processing.functions import _get_input_dir, _resolve_dir
1919
from litdata.streaming.cache import Cache
2020
from litdata.utilities.encryption import FernetEncryption, RSAEncryption
@@ -58,6 +58,44 @@ def test_get_input_dir_with_s3_path():
5858
assert input_dir.url == "s3://my_bucket/my_folder"
5959

6060

61+
def update_msg(file_path: Path, output_dir: Path):
62+
with open(os.path.join(output_dir, file_path.name), "w") as f:
63+
f.write("Bonjour!")
64+
65+
66+
def test_map_with_path(tmpdir):
67+
input_dir = Path(tmpdir) / "input_dir"
68+
output_dir = Path(tmpdir) / "output_dir"
69+
70+
os.makedirs(input_dir, exist_ok=True)
71+
os.makedirs(output_dir, exist_ok=True)
72+
73+
for i in range(5):
74+
filepath = os.path.join(input_dir, f"{i}.txt")
75+
with open(filepath, "w") as f:
76+
f.write("hello world!")
77+
78+
# read all files in the input directory, and assert it contains hello world!
79+
for file in input_dir.iterdir():
80+
with open(file) as f:
81+
content = f.read()
82+
assert content == "hello world!"
83+
84+
inputs = list(input_dir.iterdir()) # List all files in the directory
85+
86+
map(
87+
fn=update_msg,
88+
inputs=inputs,
89+
output_dir=output_dir,
90+
)
91+
92+
# read all files in the output directory, and assert it contains Bonjour!
93+
for file in output_dir.iterdir():
94+
with open(file) as f:
95+
content = f.read()
96+
assert content == "Bonjour!"
97+
98+
6199
def compress(index):
62100
return index, index**2
63101

tests/streaming/test_resolver.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import datetime
12
import sys
23
from pathlib import Path
34
from unittest import mock
@@ -388,3 +389,17 @@ def test_resolve_dir_absolute(tmp_path, monkeypatch):
388389
link.symlink_to(src)
389390
assert link.resolve() == src
390391
assert resolver._resolve_dir(str(link)).path == str(src)
392+
393+
394+
def test_resolve_time_template():
395+
path_1 = "/logs/log_{%Y-%m}"
396+
path_2 = "/logs/my_logfile"
397+
path_3 = "/logs/log_{%Y-%m}/important"
398+
399+
current_datetime = datetime.datetime.now()
400+
curr_year = current_datetime.year
401+
curr_month = current_datetime.month
402+
403+
assert resolver._resolve_time_template(path_1) == f"/logs/log_{curr_year}-{curr_month:02d}"
404+
assert resolver._resolve_time_template(path_2) == path_2
405+
assert resolver._resolve_time_template(path_3) == f"/logs/log_{curr_year}-{curr_month:02d}/important"

0 commit comments

Comments
 (0)