Skip to content

Commit 1fa5819

Browse files
tchatonthomas
authored andcommitted
Resolve bug with the uploader (#18939)
Co-authored-by: thomas <[email protected]> (cherry picked from commit f9e82c6)
1 parent 1fe9f14 commit 1fa5819

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

src/lightning/data/streaming/data_processor.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
191191
)
192192
except Exception as e:
193193
print(e)
194-
return
195194
if os.path.isdir(output_dir.path):
196195
copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
197196
else:

tests/tests_data/streaming/test_data_processor.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,64 @@ def fn(*_, **__):
6363
assert os.listdir(remote_output_dir) == ["a.txt"]
6464

6565

66+
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
67+
def test_upload_s3_fn(tmpdir, monkeypatch):
68+
input_dir = os.path.join(tmpdir, "input_dir")
69+
os.makedirs(input_dir, exist_ok=True)
70+
71+
cache_dir = os.path.join(tmpdir, "cache_dir")
72+
os.makedirs(cache_dir, exist_ok=True)
73+
74+
remote_output_dir = os.path.join(tmpdir, "remote_output_dir")
75+
os.makedirs(remote_output_dir, exist_ok=True)
76+
77+
filepath = os.path.join(input_dir, "a.txt")
78+
79+
with open(filepath, "w") as f:
80+
f.write("HERE")
81+
82+
upload_queue = mock.MagicMock()
83+
84+
paths = [filepath, None]
85+
86+
def fn(*_, **__):
87+
value = paths.pop(0)
88+
if value is None:
89+
return value
90+
return value
91+
92+
upload_queue.get = fn
93+
94+
remove_queue = mock.MagicMock()
95+
96+
s3_client = mock.MagicMock()
97+
98+
called = False
99+
100+
def copy_file(local_filepath, *args):
101+
nonlocal called
102+
called = True
103+
from shutil import copyfile
104+
105+
copyfile(local_filepath, os.path.join(remote_output_dir.path, os.path.basename(local_filepath)))
106+
107+
s3_client.client.upload_file = copy_file
108+
109+
monkeypatch.setattr(data_processor_module, "S3Client", mock.MagicMock(return_value=s3_client))
110+
111+
assert os.listdir(remote_output_dir) == []
112+
113+
assert not called
114+
115+
_upload_fn(upload_queue, remove_queue, cache_dir, Dir(path=remote_output_dir, url="s3://url"))
116+
117+
assert called
118+
119+
assert len(paths) == 0
120+
121+
assert os.listdir(remote_output_dir) == ["a.txt"]
122+
123+
66124
@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
67125
def test_remove_target(tmpdir):
68126
input_dir = os.path.join(tmpdir, "input_dir")

0 commit comments

Comments
 (0)