
Commit cb9e5d9

tchaton, thomas, and awaelchli authored and committed
Improve s3 client support (#18920)
* update (repeated)
* Update src/lightning/data/streaming/client.py
* update (repeated)

---------

Co-authored-by: thomas <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>

(cherry picked from commit 6a0f992)
1 parent 372fec8 commit cb9e5d9

File tree

9 files changed: +151 −89 lines changed


.github/CODEOWNERS

Lines changed: 2 additions & 2 deletions
@@ -43,8 +43,8 @@

 # Data Utilities
 /examples/data/          @tchaton @nohalon @justusschock @lantiga
-/src/lightning/data/     @tchaton @nohalon @justusschock @lantiga
-/tests/tests_data        @tchaton @nohalon @justusschock @lantiga
+/src/lightning/data/     @tchaton
+/tests/tests_data        @tchaton

 # Lightning Fabric
 /src/lightning/fabric    @awaelchli @carmocca @justusschock

index_1.txt

Whitespace-only changes.
src/lightning/data/streaming/client.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
import os
from time import time
from typing import Any, Optional

from lightning.data.streaming.constants import _BOTO3_AVAILABLE

if _BOTO3_AVAILABLE:
    import boto3
    import botocore
    from botocore.credentials import InstanceMetadataProvider
    from botocore.utils import InstanceMetadataFetcher


class S3Client:
    # TODO: Generalize to support more cloud providers.

    def __init__(self, refetch_interval: int = 3300) -> None:
        self._refetch_interval = refetch_interval
        self._last_time: Optional[float] = None
        self._has_cloud_space_id: bool = "LIGHTNING_CLOUD_SPACE_ID" in os.environ
        self._client: Optional[Any] = None

    @property
    def client(self) -> Any:
        if not self._has_cloud_space_id:
            if self._client is None:
                self._client = boto3.client(
                    "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
                )
            return self._client

        # Re-generate credentials for EC2
        if self._last_time is None or (time() - self._last_time) > self._refetch_interval:
            provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5))
            credentials = provider.load()
            self._client = boto3.client(
                "s3",
                aws_access_key_id=credentials.access_key,
                aws_secret_access_key=credentials.secret_key,
                aws_session_token=credentials.token,
                config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
            )
            self._last_time = time()

        return self._client
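
For orientation, a minimal usage sketch of the new S3Client (the bucket name and object key are hypothetical; the default refetch_interval of 3300 s corresponds to 55 minutes):

from lightning.data.streaming.client import S3Client

# The boto3 client is created lazily on first access of `.client`. When
# LIGHTNING_CLOUD_SPACE_ID is set in the environment, credentials are re-fetched
# from the EC2 instance metadata once the refetch interval has elapsed; otherwise
# a single cached client with adaptive retries is reused for the whole process.
s3 = S3Client(refetch_interval=3300)
s3.client.download_file("my-bucket", "chunks/chunk-0.bin", "/tmp/chunk-0.bin")  # hypothetical bucket/key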

src/lightning/data/streaming/data_processor.py

Lines changed: 28 additions & 18 deletions
@@ -18,6 +18,7 @@

 from lightning import seed_everything
 from lightning.data.streaming import Cache
+from lightning.data.streaming.client import S3Client
 from lightning.data.streaming.constants import (
     _BOTO3_AVAILABLE,
     _DEFAULT_FAST_DEV_RUN_ITEMS,

@@ -40,7 +41,6 @@
 from lightning_cloud.resolver import _LightningSrcResolver, _LightningTargetResolver

 if _BOTO3_AVAILABLE:
-    import boto3
     import botocore

 logger = logging.Logger(__name__)

@@ -74,8 +74,8 @@ def _get_home_folder() -> str:
 def _get_cache_dir(name: Optional[str]) -> str:
     """Returns the cache directory used by the Cache to store the chunks."""
     if name is None:
-        return _get_cache_folder()
-    return os.path.join(_get_cache_folder(), name)
+        return os.path.join(_get_cache_folder(), "chunks")
+    return os.path.join(_get_cache_folder(), "chunks", name)


 def _get_cache_data_dir(name: Optional[str]) -> str:

@@ -85,10 +85,6 @@ def _get_cache_data_dir(name: Optional[str]) -> str:
     return os.path.join(_get_cache_folder(), "data", name)


-def _get_s3_client() -> Any:
-    return boto3.client("s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "standard"}))
-
-
 def _wait_for_file_to_exist(s3: Any, obj: parse.ParseResult, sleep_time: int = 2) -> Any:
     """This function check."""
     while True:

@@ -105,7 +101,7 @@ def _download_data_target(
     input_dir: str, remote_input_dir: str, cache_dir: str, queue_in: Queue, queue_out: Queue
 ) -> None:
     """This function is used to download data from a remote directory to a cache directory to optimise reading."""
-    s3 = _get_s3_client()
+    s3 = S3Client()

     while True:
         # 2. Fetch from the queue

@@ -137,7 +133,7 @@
         os.makedirs(dirpath, exist_ok=True)

         with open(local_path, "wb") as f:
-            s3.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
+            s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)

     elif os.path.isfile(remote_path):
         copyfile(remote_path, local_path)

@@ -176,7 +172,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, remote_
     obj = parse.urlparse(remote_output_dir)

     if obj.scheme == "s3":
-        s3 = _get_s3_client()
+        s3 = S3Client()

     while True:
         local_filepath: Optional[str] = upload_queue.get()

@@ -190,10 +186,14 @@
         local_filepath = os.path.join(cache_dir, local_filepath)

         if obj.scheme == "s3":
-            s3.upload_file(
-                local_filepath, obj.netloc, os.path.join(obj.path.lstrip("/"), os.path.basename(local_filepath))
-            )
-        elif os.path.isdir(remote_output_dir):
+            try:
+                s3.client.upload_file(
+                    local_filepath, obj.netloc, os.path.join(obj.path.lstrip("/"), os.path.basename(local_filepath))
+                )
+            except Exception as e:
+                print(e)
+                return
+        if os.path.isdir(remote_output_dir):
             copyfile(local_filepath, os.path.join(remote_output_dir, os.path.basename(local_filepath)))
         else:
             raise ValueError(f"The provided {remote_output_dir} isn't supported.")

@@ -611,8 +611,8 @@ def _upload_index(self, remote_output_dir: str, cache_dir: str, num_nodes: int,
         local_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

         if obj.scheme == "s3":
-            s3 = _get_s3_client()
-            s3.upload_file(
+            s3 = S3Client()
+            s3.client.upload_file(
                 local_filepath, obj.netloc, os.path.join(obj.path.lstrip("/"), os.path.basename(local_filepath))
             )
         elif os.path.isdir(remote_output_dir):

@@ -775,6 +775,7 @@ def run(self, data_recipe: DataRecipe) -> None:
         print("Workers are ready ! Starting data processing...")

         current_total = 0
+        has_failed = False
         with tqdm(total=num_items, smoothing=0, position=-1, mininterval=1) as pbar:
             while True:
                 try:

@@ -788,15 +789,20 @@
                     continue
                 self.workers_tracker[index] = counter
                 new_total = sum(self.workers_tracker.values())
+
                 pbar.update(new_total - current_total)
                 current_total = new_total
                 if current_total == num_items:
                     break

-        num_nodes = _get_num_nodes()
+                # Exit early if all the workers are done.
+                # This means there were some kind of errors.
+                if all(not w.is_alive() for w in self.workers):
+                    has_failed = True
+                    break

         # TODO: Understand why it hangs.
-        if num_nodes == 1:
+        if _get_num_nodes() == 1:
             for w in self.workers:
                 w.join(0)

@@ -806,6 +812,10 @@
         data_recipe._done(self.delete_cached_files, self.remote_output_dir)
         print("Finished data processing!")

+        # TODO: Understand why it is required to avoid long shutdown.
+        if _get_num_nodes() > 1:
+            os._exit(int(has_failed))
+
     def _exit_on_error(self, error: str) -> None:
         for w in self.workers:
             w.join(0)
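
A small sketch of the exit-code convention behind the new os._exit call at the end of run() (num_nodes stands in for whatever _get_num_nodes() returns; this is an illustration, not the DataProcessor code itself):

import os

num_nodes = 2        # stand-in for _get_num_nodes()
has_failed = False   # flipped to True when every worker died before all items were processed

# os._exit terminates the process immediately, skipping the interpreter's normal
# shutdown (atexit handlers, joining child processes), which is presumably why it
# avoids the long shutdown mentioned in the TODO. The exit status encodes the
# outcome: int(False) == 0 signals success, int(True) == 1 signals failure.
if num_nodes > 1:
    os._exit(int(has_failed))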

src/lightning/data/streaming/downloader.py

Lines changed: 8 additions & 11 deletions
@@ -16,6 +16,8 @@
 from typing import Any, Dict, List, Type
 from urllib import parse

+from lightning.data.streaming.client import S3Client
+

 class Downloader(ABC):
     def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):

@@ -37,25 +39,20 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
 class S3Downloader(Downloader):
     @classmethod
     def download_file(cls, remote_filepath: str, local_filepath: str) -> None:
-        import boto3
-        from boto3.s3.transfer import TransferConfig
-        from botocore.config import Config
-
         obj = parse.urlparse(remote_filepath)

         if obj.scheme != "s3":
             raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")

+        # TODO: Add caching to avoid re-creating it
+        s3 = S3Client()
+
+        from boto3.s3.transfer import TransferConfig
+
         extra_args: Dict[str, Any] = {}

-        # Create a new session per thread
-        session = boto3.session.Session()
-        # Create a resource client using a thread's session object
-        s3 = session.client("s3", config=Config(read_timeout=None))
-        # Threads calling S3 operations return RuntimeError (cannot schedule new futures after
-        # interpreter shutdown). Temporary solution is to have `use_threads` as `False`.
         # Issue: https://github.com/boto/boto3/issues/3113
-        s3.download_file(
+        s3.client.download_file(
             obj.netloc,
             obj.path.lstrip("/"),
             local_filepath,
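
A hedged sketch of how the refactored S3Downloader path is exercised (the s3:// URL and local path are hypothetical):

from lightning.data.streaming.downloader import S3Downloader

# download_file is a classmethod: it parses the s3:// URL, validates the scheme,
# and now delegates to the shared S3Client instead of building a fresh boto3
# session on every call.
S3Downloader.download_file("s3://my-bucket/chunks/chunk-0.bin", "/tmp/chunk-0.bin")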

src/lightning/data/streaming/functions.py

Lines changed: 6 additions & 0 deletions
@@ -92,6 +92,7 @@ def map(
     num_nodes: Optional[int] = None,
     machine: Optional[str] = None,
     input_dir: Optional[str] = None,
+    num_downloaders: int = 1,
 ) -> None:
     """This function map a callbable over a collection of files possibly in a distributed way.


@@ -104,6 +105,7 @@
         fast_dev_run: Whether to use process only a sub part of the inputs
         num_nodes: When doing remote execution, the number of nodes to use.
         machine: When doing remote execution, the machine to use.
+        num_downloaders: The number of downloaders per worker.

     """
     if not isinstance(inputs, Sequence):

@@ -127,6 +129,7 @@
             fast_dev_run=fast_dev_run,
             version=None,
             input_dir=input_dir or _get_input_dir(inputs),
+            num_downloaders=num_downloaders,
         )
         return data_processor.run(LambdaDataTransformRecipe(fn, inputs))
     return _execute(

@@ -149,6 +152,7 @@ def optimize(
     num_nodes: Optional[int] = None,
     machine: Optional[str] = None,
     input_dir: Optional[str] = None,
+    num_downloaders: int = 1,
 ) -> None:
     """This function converts a dataset into chunks possibly in a distributed way.


@@ -164,6 +168,7 @@
         fast_dev_run: Whether to use process only a sub part of the inputs
         num_nodes: When doing remote execution, the number of nodes to use.
         machine: When doing remote execution, the machine to use.
+        num_downloaders: The number of downloaders per worker.

     """
     if not isinstance(inputs, Sequence):

@@ -190,6 +195,7 @@
             remote_output_dir=PrettyDirectory(output_dir, remote_output_dir),
             fast_dev_run=fast_dev_run,
             input_dir=input_dir or _get_input_dir(inputs),
+            num_downloaders=num_downloaders,
         )
         return data_processor.run(
             LambdaDataChunkRecipe(
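
A usage sketch of the new num_downloaders argument (the transform function, input paths, and output_dir below are hypothetical; the remaining keyword arguments follow the signatures shown in this diff):

from lightning.data.streaming.functions import map

def resize(filepath):
    ...  # hypothetical per-file transform

map(
    fn=resize,
    inputs=["s3://my-bucket/images/0.png", "s3://my-bucket/images/1.png"],
    output_dir="/cache/resized",
    num_downloaders=2,  # new in this commit: number of download processes per worker
)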

src/lightning/data/streaming/map.py

Lines changed: 0 additions & 50 deletions
This file was deleted.
Lines changed: 50 additions & 0 deletions (new file; the tests below exercise lightning.data.streaming.client)

@@ -0,0 +1,50 @@
from time import sleep, time
from unittest import mock

from lightning.data.streaming import client


def test_s3_client_without_cloud_space_id(monkeypatch):
    boto3 = mock.MagicMock()
    monkeypatch.setattr(client, "boto3", boto3)

    botocore = mock.MagicMock()
    monkeypatch.setattr(client, "botocore", botocore)

    s3 = client.S3Client(1)
    assert s3.client
    assert s3.client
    assert s3.client
    assert s3.client
    assert s3.client

    boto3.client.assert_called_once()


def test_s3_client_with_cloud_space_id(monkeypatch):
    boto3 = mock.MagicMock()
    monkeypatch.setattr(client, "boto3", boto3)

    botocore = mock.MagicMock()
    monkeypatch.setattr(client, "botocore", botocore)

    instance_metadata_provider = mock.MagicMock()
    monkeypatch.setattr(client, "InstanceMetadataProvider", instance_metadata_provider)

    instance_metadata_fetcher = mock.MagicMock()
    monkeypatch.setattr(client, "InstanceMetadataFetcher", instance_metadata_fetcher)

    monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "dummy")

    s3 = client.S3Client(1)
    assert s3.client
    assert s3.client
    boto3.client.assert_called_once()
    sleep(1 - (time() - s3._last_time))
    assert s3.client
    assert s3.client
    assert len(boto3.client._mock_mock_calls) == 6
    sleep(1 - (time() - s3._last_time))
    assert s3.client
    assert s3.client
    assert len(boto3.client._mock_mock_calls) == 9
