Skip to content

Commit af7e79a

Browse files
authored
Data Processing: Tiny optimization (#19389)
1 parent 6296a4f commit af7e79a

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

src/lightning/data/streaming/data_processor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
1717
from urllib import parse
1818

19+
import numpy as np
1920
from tqdm.auto import tqdm as _tqdm
2021

2122
from lightning import seed_everything
@@ -290,7 +291,7 @@ def _map_items_to_workers_weighted(
290291
else:
291292
print(f"Worker {worker_id} gets ({len(worker_items[worker_id])}) items for a total weight of {size}.")
292293

293-
return [worker_items[worker_id] for worker_id in worker_ids_this_node]
294+
return [np.random.permutation(worker_items[worker_id]).tolist() for worker_id in worker_ids_this_node]
294295

295296

296297
def _get_num_bytes(item: Any, base_path: str) -> int:

tests/tests_data/streaming/test_data_processor.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -255,50 +255,52 @@ def test_cache_dir_cleanup(tmpdir, monkeypatch):
255255

256256

257257
def test_map_items_to_workers_weighted(monkeypatch):
258+
seed_everything(42)
259+
258260
workers_user_items = _map_items_to_workers_weighted(1, list(range(5)))
259-
assert workers_user_items == [list(range(5))]
261+
assert workers_user_items == [[1, 4, 2, 0, 3]]
260262
workers_user_items = _map_items_to_workers_weighted(2, list(range(5)))
261-
assert workers_user_items == [[0, 2, 4], [1, 3]]
263+
assert workers_user_items == [[2, 4, 0], [3, 1]]
262264
workers_user_items = _map_items_to_workers_weighted(3, list(range(5)))
263-
assert workers_user_items == [[0, 3], [1, 4], [2]]
265+
assert workers_user_items == [[0, 3], [4, 1], [2]]
264266
workers_user_items = _map_items_to_workers_weighted(4, list(range(5)))
265-
assert workers_user_items == [[0, 4], [1], [2], [3]]
267+
assert workers_user_items == [[4, 0], [1], [2], [3]]
266268

267269
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
268270
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
269271
workers_user_items = _map_items_to_workers_weighted(1, list(range(5)))
270-
assert workers_user_items == [[0, 2, 4]]
272+
assert workers_user_items == [[2, 0, 4]]
271273
workers_user_items = _map_items_to_workers_weighted(2, list(range(5)))
272274
assert workers_user_items == [[0, 4], [1]]
273275

274276
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "2")
275277
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "1")
276278
workers_user_items = _map_items_to_workers_weighted(1, list(range(5)))
277-
assert workers_user_items == [[1, 3]]
279+
assert workers_user_items == [[3, 1]]
278280
workers_user_items = _map_items_to_workers_weighted(2, list(range(5)))
279281
assert workers_user_items == [[2], [3]]
280282

281283
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
282284
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "0")
283285
workers_user_items = _map_items_to_workers_weighted(1, list(range(32)))
284-
assert workers_user_items == [[0, 4, 8, 12, 16, 20, 24, 28]]
286+
assert workers_user_items == [[0, 24, 28, 4, 16, 20, 8, 12]]
285287
workers_user_items = _map_items_to_workers_weighted(2, list(range(32)))
286-
assert workers_user_items == [[0, 8, 16, 24], [1, 9, 17, 25]]
288+
assert workers_user_items == [[24, 16, 0, 8], [1, 17, 9, 25]]
287289
workers_user_items = _map_items_to_workers_weighted(3, list(range(32)))
288-
assert workers_user_items == [[0, 12, 24], [1, 13, 25], [2, 14, 26]]
290+
assert workers_user_items == [[24, 12, 0], [13, 25, 1], [14, 2, 26]]
289291
workers_user_items = _map_items_to_workers_weighted(4, list(range(32)))
290-
assert workers_user_items == [[0, 16], [1, 17], [2, 18], [3, 19]]
292+
assert workers_user_items == [[16, 0], [1, 17], [2, 18], [3, 19]]
291293

292294
monkeypatch.setenv("DATA_OPTIMIZER_NUM_NODES", "4")
293295
monkeypatch.setenv("DATA_OPTIMIZER_NODE_RANK", "3")
294296
workers_user_items = _map_items_to_workers_weighted(1, list(range(32)))
295-
assert workers_user_items == [[3, 7, 11, 15, 19, 23, 27, 31]]
297+
assert workers_user_items == [[3, 7, 19, 31, 11, 23, 27, 15]]
296298
workers_user_items = _map_items_to_workers_weighted(2, list(range(32)))
297-
assert workers_user_items == [[6, 14, 22, 30], [7, 15, 23, 31]]
299+
assert workers_user_items == [[14, 22, 6, 30], [15, 31, 23, 7]]
298300
workers_user_items = _map_items_to_workers_weighted(3, list(range(32)))
299-
assert workers_user_items == [[9, 21], [10, 22], [11, 23]]
301+
assert workers_user_items == [[21, 9], [22, 10], [23, 11]]
300302
workers_user_items = _map_items_to_workers_weighted(4, list(range(32)))
301-
assert workers_user_items == [[12, 28], [13, 29], [14, 30], [15, 31]]
303+
assert workers_user_items == [[12, 28], [13, 29], [30, 14], [15, 31]]
302304

303305

304306
def test_map_items_to_workers_sequentially(monkeypatch):

0 commit comments

Comments
 (0)