@@ -13,32 +13,38 @@
 
 import logging
 import os
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from lightning.data.datasets.env import _DistributedEnv
 from lightning.data.streaming.constants import (
     _INDEX_FILENAME,
-    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46,
+    _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
 from lightning.data.streaming.item_loader import BaseItemLoader
 from lightning.data.streaming.reader import BinaryReader
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.writer import BinaryWriter
 
-if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_46:
-    from lightning_cloud.resolver import _find_remote_dir, _try_create_cache_dir
-
 logger = logging.Logger(__name__)
 
+if _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
+    from lightning_cloud.resolver import _resolve_dir
+
+
+@dataclass
+class Dir:
+    """Holds a directory path and possibly its associated remote URL."""
+
+    path: str
+    url: Optional[str] = None
+
 
 class Cache:
     def __init__(
         self,
-        cache_dir: Optional[str] = None,
-        remote_dir: Optional[str] = None,
-        name: Optional[str] = None,
-        version: Optional[Union[int, Literal["latest"]]] = "latest",
+        input_dir: Optional[Union[str, Dir]],
         compression: Optional[str] = None,
         chunk_size: Optional[int] = None,
         chunk_bytes: Optional[Union[int, str]] = None,
@@ -48,9 +54,7 @@ def __init__(
         together in order to accelerate fetching.
 
         Arguments:
-            cache_dir: The path to where the chunks will be stored.
-            remote_dir: The path to a remote folder where the data are located.
-                The scheme needs to be added to the path.
+            input_dir: The path to where the chunks will be or are stored.
             name: The name of dataset in the cloud.
            version: The version of the dataset in the cloud to use. By default, we will use the latest.
             compression: The name of the algorithm to reduce the size of the chunks.
@@ -63,25 +67,20 @@ def __init__(
         if not _TORCH_GREATER_EQUAL_2_1_0:
             raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")
 
-        self._cache_dir = cache_dir = str(cache_dir) if cache_dir else _try_create_cache_dir(name)
-        if not remote_dir:
-            remote_dir, has_index_file = _find_remote_dir(name, version)
-
-            # When the index exists, we don't care about the chunk_size anymore.
-            if has_index_file and (chunk_size is None and chunk_bytes is None):
-                chunk_size = 2
-
-        # Add the version to the cache_dir to avoid collisions.
-        if remote_dir and os.path.basename(remote_dir).startswith("version_"):
-            cache_dir = os.path.join(cache_dir, os.path.basename(remote_dir))
-
-        if cache_dir:
-            os.makedirs(cache_dir, exist_ok=True)
-
-        self._cache_dir = cache_dir
-
-        self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
-        self._reader = BinaryReader(cache_dir, remote_dir=remote_dir, compression=compression, item_loader=item_loader)
+        if not _LIGHTNING_CLOUD_GREATER_EQUAL_0_5_48:
+            raise ModuleNotFoundError("Lightning Cloud 0.5.48 or higher is required to use the cache.")
+
+        input_dir = _resolve_dir(input_dir)
+        self._cache_dir = input_dir.path
+        self._writer = BinaryWriter(
+            self._cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression
+        )
+        self._reader = BinaryReader(
+            self._cache_dir,
+            remote_input_dir=input_dir.url,
+            compression=compression,
+            item_loader=item_loader,
+        )
         self._is_done = False
         self._distributed_env = _DistributedEnv.detect()
 
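For reference, a minimal usage sketch of the new `input_dir` argument. The import path, local directory, chunk size, and S3 URL below are illustrative placeholders, and it assumes `_resolve_dir` passes an explicit `Dir` through unchanged:

```python
from lightning.data.streaming.cache import Cache, Dir  # assumed module path

# Purely local cache: chunks are written to and read back from ./chunks.
cache = Cache("./chunks", chunk_bytes=1 << 26)  # ~64 MiB per chunk (placeholder value)

# Local path paired with a remote URL: BinaryReader receives the URL as
# remote_input_dir, so chunks missing locally can be fetched into the path.
cache = Cache(Dir(path="./chunks", url="s3://my-bucket/dataset/chunks"))
```

Folding `cache_dir`, `remote_dir`, `name`, and `version` into a single `input_dir` moves all path resolution into `_resolve_dir`, so `Cache` itself only ever sees a resolved `Dir` carrying a local `path` and an optional remote `url`.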