
Commit 1284713

tchaton and thomas authored
lightning.data: Remove torch distributed for the Dataset Optimizer (#19182)
Co-authored-by: thomas <[email protected]>
1 parent 0a5cca6 commit 1284713
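
In short, this commit removes DataProcessor's dependence on torch.distributed for synchronizing the input and output directories across nodes and replaces it with an HTTP-backed broadcast_object utility: each node posts its value for a key to the app's /broadcast endpoint and gets back whichever value was registered first for that key. A minimal usage sketch, assuming the environment variables described in the new module further down (outside the Lightning cloud the call simply returns the local value):

    from lightning.data.utilities.broadcast import broadcast_object

    # Without LIGHTNING_APP_EXTERNAL_URL set, broadcast_object is a no-op and returns
    # the local value unchanged.
    input_dir = broadcast_object("input_dir", "/data/my-dataset")

    # On a multi-node cloud run, LIGHTNING_APP_EXTERNAL_URL and LIGHTNING_APP_STATE_URL
    # are expected to be set by the platform; every node then receives the value the
    # first node registered under the key "input_dir".
    print(input_dir)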

5 files changed, +186 -37 lines changed

src/lightning/app/utilities/network.py

Lines changed: 2 additions & 2 deletions

@@ -191,9 +191,9 @@ def get(self, path: str):
         return self.session.get(url)
 
     @_http_method_logger_wrapper
-    def post(self, path: str, *, query_params: Optional[Dict] = None, data: Optional[bytes] = None):
+    def post(self, path: str, *, query_params: Optional[Dict] = None, data: Optional[bytes] = None, json: Any = None):
         url = urljoin(self.base_url, path)
-        return self.session.post(url, data=data, params=query_params)
+        return self.session.post(url, data=data, params=query_params, json=json)
 
     @_http_method_logger_wrapper
     def delete(self, path: str):
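
The only functional change here is that post() now forwards a json payload to requests, which the new broadcast utility in this commit relies on. For context, json= is the standard requests keyword argument: the dict is serialized to a JSON body and the Content-Type header is set automatically. A minimal illustration with a plain requests session (the URL is a placeholder, not part of this commit):

    import requests

    session = requests.Session()
    # Equivalent to what the patched post() forwards: requests serializes the dict to a
    # JSON body and sets Content-Type: application/json automatically.
    resp = session.post("http://localhost:7501/broadcast", json={"key": "input_dir", "value": "/data"})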

src/lightning/data/streaming/data_processor.py

Lines changed: 3 additions & 24 deletions

@@ -14,7 +14,6 @@
 from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
 from urllib import parse
 
-import torch
 from tqdm.auto import tqdm as _tqdm
 
 from lightning import seed_everything
@@ -28,14 +27,8 @@
     _LIGHTNING_CLOUD_LATEST,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
+from lightning.data.utilities.broadcast import broadcast_object
 from lightning.data.utilities.packing import _pack_greedily
-from lightning.fabric.accelerators.cuda import is_cuda_available
-from lightning.fabric.plugins.environments import LightningEnvironment
-from lightning.fabric.utilities.distributed import (
-    _distributed_is_initialized,
-    _init_dist_connection,
-)
-from lightning.fabric.utilities.distributed import group as _group
 
 if _TORCH_GREATER_EQUAL_2_1_0:
     from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads
@@ -785,11 +778,11 @@ def __init__(
         self.reorder_files = reorder_files
 
         # Ensure the input dir is the same across all nodes
-        self.input_dir = self._broadcast_object(self.input_dir)
+        self.input_dir = broadcast_object("input_dir", self.input_dir)
 
         if self.output_dir:
             # Ensure the output dir is the same across all nodes
-            self.output_dir = self._broadcast_object(self.output_dir)
+            self.output_dir = broadcast_object("output_dir", self.output_dir)
             print(f"Storing the files under {self.output_dir.path}")
 
         self.random_seed = random_seed
@@ -971,17 +964,3 @@ def _cleanup_cache(self) -> None:
         shutil.rmtree(cache_data_dir, ignore_errors=True)
 
         os.makedirs(cache_data_dir, exist_ok=True)
-
-    def _broadcast_object(self, obj: Any) -> Any:
-        """Enable to synchronize an object across machines using torch.distributed.collectives."""
-        num_nodes = _get_num_nodes()
-        if num_nodes == 1:
-            return obj
-
-        if not _distributed_is_initialized():
-            process_group_backend = "nccl" if is_cuda_available() else "gloo"
-            _init_dist_connection(LightningEnvironment(), process_group_backend, _get_node_rank(), num_nodes)
-
-        obj = [obj]
-        torch.distributed.broadcast_object_list(obj, 0, group=_group.WORLD)
-        return obj[0]
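
For reference, the removed helper relied on the torch.distributed.broadcast_object_list collective, which is what tied the Dataset Optimizer to a torch process group. A minimal single-process sketch of that collective (the group setup here is illustrative only; the deleted code built it from Lightning's environment helpers):

    import torch.distributed as dist

    # A one-process gloo group is enough to exercise the collective locally.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
    obj = ["some value"]
    dist.broadcast_object_list(obj, src=0)  # every rank ends up with rank 0's value
    print(obj[0])
    dist.destroy_process_group()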

src/lightning/data/utilities/broadcast.py

Lines changed: 159 additions & 0 deletions

@@ -0,0 +1,159 @@
+# Copyright The Lightning AI team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import pickle
+from logging import Logger
+from typing import Any, Callable, Dict, Optional
+from urllib.parse import urljoin
+
+import requests
+import urllib3
+
+# for backwards compatibility
+from requests.adapters import HTTPAdapter
+from urllib3.util.retry import Retry
+
+logger = Logger(__name__)
+
+_CONNECTION_RETRY_TOTAL = 2880
+_CONNECTION_RETRY_BACKOFF_FACTOR = 0.5
+_DEFAULT_REQUEST_TIMEOUT = 30  # seconds
+
+
+class _CustomRetryAdapter(HTTPAdapter):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.timeout = kwargs.pop("timeout", _DEFAULT_REQUEST_TIMEOUT)
+        super().__init__(*args, **kwargs)
+
+    def send(self, request: Any, *args: Any, **kwargs: Any) -> Any:
+        kwargs["timeout"] = kwargs.get("timeout", self.timeout)
+        return super().send(request, **kwargs)
+
+
+def _response(r: Any, *args: Any, **kwargs: Any) -> Any:
+    return r.raise_for_status()
+
+
+class _HTTPClient:
+    """A wrapper class around the requests library which handles chores like logging, retries, and timeouts
+    automatically."""
+
+    def __init__(
+        self,
+        base_url: str,
+        auth_token: Optional[str] = None,
+        log_callback: Optional[Callable] = None,
+        use_retry: bool = True,
+    ) -> None:
+        self.base_url = base_url
+        retry_strategy = Retry(
+            # wait time between retries increases exponentially according to: backoff_factor * (2 ** (retry - 1))
+            # but the maximum wait time is 120 secs. By setting a large value (2880), we'll make sure clients
+            # are going to be alive for a very long time (~ 4 days) but retries every 120 seconds
+            total=_CONNECTION_RETRY_TOTAL,
+            backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
+            status_forcelist=[
+                408,  # Request Timeout
+                429,  # Too Many Requests
+                500,  # Internal Server Error
+                502,  # Bad Gateway
+                503,  # Service Unavailable
+                504,  # Gateway Timeout
+            ],
+        )
+        adapter = _CustomRetryAdapter(max_retries=retry_strategy, timeout=_DEFAULT_REQUEST_TIMEOUT)
+        self.session = requests.Session()
+
+        self.session.hooks = {"response": _response}
+
+        if use_retry:
+            self.session.mount("http://", adapter)
+            self.session.mount("https://", adapter)
+
+        if auth_token:
+            self.session.headers.update({"Authorization": f"Bearer {auth_token}"})
+
+    def get(self, path: str) -> Any:
+        url = urljoin(self.base_url, path)
+        return self.session.get(url)
+
+    def post(
+        self, path: str, *, query_params: Optional[Dict] = None, data: Optional[bytes] = None, json: Any = None
+    ) -> Any:
+        url = urljoin(self.base_url, path)
+        return self.session.post(url, data=data, params=query_params, json=json)
+
+    def delete(self, path: str) -> Any:
+        url = urljoin(self.base_url, path)
+        return self.session.delete(url)
+
+
+class _ImmutableDistributedMap:
+    """The _ImmutableDistributedMap enables to create a distributed key value pair in the cloud.
+
+    The first process to perform the set operation defines its value.
+
+    """
+
+    def __init__(self) -> None:
+        token = _get_token()
+
+        lightning_app_external_url = os.getenv("LIGHTNING_APP_EXTERNAL_URL")
+        if lightning_app_external_url is None:
+            raise RuntimeError("The `LIGHTNING_APP_EXTERNAL_URL` should be set.")
+
+        self.public_client: _HTTPClient = _HTTPClient(lightning_app_external_url, auth_token=token, use_retry=False)
+
+        lightning_app_state_url = os.getenv("LIGHTNING_APP_STATE_URL")
+        if lightning_app_state_url is None:
+            raise RuntimeError("The `LIGHTNING_APP_STATE_URL` should be set.")
+
+        self.private_client: _HTTPClient = _HTTPClient(lightning_app_state_url, auth_token=token, use_retry=False)
+
+    def set_and_get(self, key: str, value: Any) -> Any:
+        payload = {"key": key, "value": pickle.dumps(value, 0).decode()}
+
+        # Try the public address first
+        try:
+            resp = self.public_client.post("/broadcast", json=payload)
+        except (requests.exceptions.ConnectionError, urllib3.exceptions.MaxRetryError):
+            # fallback to the private one
+            resp = self.private_client.post("/broadcast", json=payload)
+
+        if resp.status_code != 200:
+            raise RuntimeError(f"Failed to broadcast the following {key=} {value=}.")
+        return pickle.loads(bytes(resp.json()["value"], "utf-8"))
+
+
+def broadcast_object(key: str, obj: Any) -> Any:
+    """This function enables to broadcast object across machines."""
+    if os.getenv("LIGHTNING_APP_EXTERNAL_URL") is not None:
+        return _ImmutableDistributedMap().set_and_get(key, obj)
+    return obj
+
+
+def _get_token() -> Optional[str]:
+    """This function tries to retrieve a temporary token."""
+    if os.getenv("LIGHTNING_CLOUD_URL") is None:
+        return None
+
+    payload = {"apiKey": os.getenv("LIGHTNING_API_KEY"), "username": os.getenv("LIGHTNING_USERNAME")}
+    url_login = os.getenv("LIGHTNING_CLOUD_URL", "") + "/v1/auth/login"
+    res = requests.post(url_login, data=json.dumps(payload))
+    if "token" not in res.json():
+        raise RuntimeError(
+            f"You haven't properly setup your environment variables with {url_login} and data: \n{payload}"
+        )
+    return res.json()["token"]
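
The retry numbers in _CustomRetryAdapter can be sanity-checked: with backoff_factor = 0.5 the sleep before retry n is 0.5 * 2 ** (n - 1) seconds, urllib3 caps each individual sleep (120 s by default), so 2880 retries keep a client retrying for roughly 2880 * 120 s, about four days, as the inline comment claims. A small sketch of that arithmetic (the cap value is urllib3's default, not something set in this file):

    BACKOFF_FACTOR = 0.5
    BACKOFF_MAX = 120  # urllib3's default cap on a single backoff sleep, in seconds
    TOTAL_RETRIES = 2880

    def backoff(retry_number: int) -> float:
        """Sleep before the given retry: backoff_factor * 2 ** (retry - 1), capped."""
        return min(BACKOFF_FACTOR * 2 ** (retry_number - 1), BACKOFF_MAX)

    total_wait = sum(backoff(n) for n in range(1, TOTAL_RETRIES + 1))
    print(f"{total_wait / 86400:.1f} days")  # roughly 4 days of retrying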

tests/tests_data/streaming/test_data_processor.py

Lines changed: 0 additions & 11 deletions

@@ -233,17 +233,6 @@ def fn(*_, **__):
     _wait_for_file_to_exist(s3, obj, sleep_time=0.01)
 
 
-def test_broadcast_object(tmpdir, monkeypatch):
-    data_processor = DataProcessor(input_dir=str(tmpdir))
-    assert data_processor._broadcast_object("dummy") == "dummy"
-    monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
-    monkeypatch.setattr(data_processor_module, "_distributed_is_initialized", lambda: True)
-    torch_mock = mock.MagicMock()
-    monkeypatch.setattr(data_processor_module, "torch", torch_mock)
-    assert data_processor._broadcast_object("dummy") == "dummy"
-    assert torch_mock.distributed.broadcast_object_list._mock_call_args.args == (["dummy"], 0)
-
-
 def test_cache_dir_cleanup(tmpdir, monkeypatch):
     cache_dir = os.path.join(tmpdir, "chunks")
     cache_data_dir = os.path.join(tmpdir, "data")

tests/tests_data/utilities/test_broadcast.py

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+import os
+from unittest import mock
+
+from lightning.data.utilities.broadcast import broadcast_object, requests
+
+
+@mock.patch.dict(
+    os.environ, {"LIGHTNING_APP_EXTERNAL_URL": "http://", "LIGHTNING_APP_STATE_URL": "http://"}, clear=True
+)
+def test_broadcast(monkeypatch):
+    session = mock.MagicMock()
+    resp = requests.Response()
+    resp.status_code = 200
+
+    def fn(*args, **kwargs):
+        nonlocal session
+        return {"value": session.post._mock_call_args_list[0].kwargs["json"]["value"]}
+
+    resp.json = fn
+    session.post.return_value = resp
+    monkeypatch.setattr(requests, "Session", mock.MagicMock(return_value=session))
+    assert broadcast_object("key", "value") == "value"
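
The test can echo the posted "value" straight back because of the wire format set_and_get uses: values are pickled with protocol 0, which is plain ASCII and therefore survives the JSON round trip as an ordinary string. A standalone illustration of that encoding (the sample value is arbitrary):

    import pickle

    value = {"input_dir": "/data/my-dataset"}
    wire = pickle.dumps(value, 0).decode()  # protocol 0 is ASCII text, safe to embed in JSON
    assert pickle.loads(bytes(wire, "utf-8")) == value  # the same decode step broadcast_object performs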
