
Commit dcb580e

rajdchak authored and IsaevIlya committed
Integration tests for dcp (#257)
* Integration tests for dcp
* Fix in tests for S3 express
* Made formatting changes
* Addressed comments
* Addressed comments
* Addressed comments
* Addressed comments
* Addressed comments
* Fixed issues with copy
* Fixed issue with port number
* Fixed issue with port number
1 parent ae458f2 commit dcb580e
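The new tests boil down to a save/load roundtrip of a tensor state dict through torch.distributed.checkpoint (DCP) with the connector's S3StorageWriter and S3StorageReader, spawned across 6 gloo ranks. The writer produces one shard object per rank and thread plus a single .metadata object, which is what the world_size * thread_count + 1 assertion in load_data counts. A minimal single-process sketch of the same path (the region and bucket URI are placeholders; recent PyTorch releases fall back to non-distributed mode when no process group is initialized):

    import torch
    import torch.distributed.checkpoint as dcp

    from s3torchconnector.dcp import S3StorageReader, S3StorageWriter

    REGION = "us-east-1"                     # placeholder region
    PATH = "s3://my-bucket/checkpoint_demo"  # placeholder S3 URI

    state = {"weights": torch.randn(3, 2)}

    # Save: writes shard objects and a .metadata object under PATH.
    dcp.save(
        state,
        storage_writer=S3StorageWriter(region=REGION, path=PATH, overwrite=True),
    )

    # Load in place: DCP restores values into a pre-allocated state dict.
    restored = {"weights": torch.empty(3, 2)}
    dcp.load(restored, storage_reader=S3StorageReader(REGION, PATH))
    assert torch.allclose(state["weights"], restored["weights"])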

File tree: 1 file changed (+234, -6 lines)

s3torchconnector/tst/e2e/dcp/test_e2e_s3_file_system.py

Lines changed: 234 additions & 6 deletions
@@ -4,19 +4,247 @@
 import pytest
 import torch
 import torch.distributed.checkpoint as dcp
+
+import torch.distributed as dist
 from torch.distributed.checkpoint import CheckpointException
+import torch.multiprocessing as mp
 
 from s3torchconnector.dcp import S3StorageWriter, S3StorageReader
+from s3torchconnector._s3client import S3Client
+from s3torchconnector._s3dataset_common import parse_s3_uri
+import os
+import random
+
+
+def generate_random_port():
+    return random.randint(1, 500)
+
+
+def setup(rank, world_size, port):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = port
+    dist.init_process_group("gloo", rank=rank, world_size=world_size)
+
+
+def cleanup():
+    dist.destroy_process_group()
+
+
+def run(
+    rank,
+    world_size,
+    threads,
+    region,
+    s3_path_s3storagewriter,
+    test_data,
+    port,
+):
+    print(f"Running on rank {rank}.")
+
+    setup(rank, world_size, port)
+    # Save using S3StorageWriter
+    dcp_save(
+        test_data,
+        S3StorageWriter(
+            region=region,
+            thread_count=threads,
+            path=s3_path_s3storagewriter,
+            overwrite=True,
+        ),
+    )
+
+    cleanup()
+
+
+def multi_process_dcp_save_load(
+    world_size, thread_count, checkpoint_directory, tensor_dimensions, port_offset
+):
+    region = checkpoint_directory.region
+    s3_path_s3storagewriter = f"{checkpoint_directory.s3_uri}checkpoint_s3storagewriter"
+    s3_path_s3storagewriter = s3_path_s3storagewriter.replace("[", "_").replace(
+        "]", "_"
+    )
+
+    test_data = {
+        "tensor1": torch.randn(tensor_dimensions),
+        "tensor2": torch.randn(5, 5),
+        "scalar": torch.tensor(3.14),
+    }
+
+    port = str(generate_random_port() + port_offset)
+    mp.spawn(
+        run,
+        args=(
+            world_size,
+            thread_count,
+            region,
+            s3_path_s3storagewriter,
+            test_data,
+            port,
+        ),
+        nprocs=world_size,
+        join=True,
+    )
+
+    load_data(
+        region,
+        s3_path_s3storagewriter,
+        test_data,
+        world_size,
+        thread_count,
+    )
+
+
+def dcp_save(data, writer):
+    dcp.save(
+        data,
+        storage_writer=writer,
+    )
+
+
+def dcp_load(loaded_data, reader):
+    dcp.load(
+        loaded_data,
+        storage_reader=reader,
+    )
+
+
+def load_data(region, s3_path_s3storagewriter, test_data, world_size, thread_count):
+    s3_client = S3Client(region=region)
+    bucket, key = parse_s3_uri(s3_path_s3storagewriter)
+    list_result_s3storagewriter = list(s3_client.list_objects(bucket, f"{key}/"))
+
+    # Compare length
+    assert list_result_s3storagewriter is not None
+    assert (
+        len(list_result_s3storagewriter[0].object_info) == world_size * thread_count + 1
+    )
+
+    # Load using S3StorageReader
+    loaded_data_s3storagereader = {}
+    dcp_load(
+        loaded_data_s3storagereader,
+        S3StorageReader(
+            region,
+            s3_path_s3storagewriter,
+        ),
+    )
+
+    for key in loaded_data_s3storagereader.keys():
+        assert torch.allclose(
+            loaded_data_s3storagereader[key], test_data[key]
+        ), f"S3StorageReader: Loaded tensor for key '{key}' does not match original"
+
+    print("Test passed: Saved and loaded data correctly.")
+
+
+@pytest.mark.parametrize(
+    "tensor_dimensions, thread_count, port_offset",
+    [
+        ([3, 2], 1, 20000),
+        ([10, 1024, 1024], 1, 30000),
+        ([3, 2], 4, 40000),
+        ([10, 1024, 1024], 4, 50000),
+    ],
+    ids=[
+        "small_tensor_single_thread",
+        "large_tensor_single_thread",
+        "small_tensor_multi_thread",
+        "large_tensor_multi_thread",
+    ],
+)
+def test_dcp_when_multi_process(
+    checkpoint_directory, tensor_dimensions, thread_count, port_offset
+):
+    multi_process_dcp_save_load(
+        6, thread_count, checkpoint_directory, tensor_dimensions, port_offset
+    )
+
+
+def test_dcp_save_non_existing_s3_uri(checkpoint_directory):
+    t1 = torch.randn(10)
+    region = checkpoint_directory.region
+    non_existing_s3_uri = "s3://non-existing-bucket/checkpoint"
 
+    with pytest.raises(CheckpointException) as s3_excinfo:
+        dcp_save(
+            {"random": t1},
+            S3StorageWriter(
+                region,
+                non_existing_s3_uri,
+                overwrite=True,
+            ),
+        )
+
+    assert isinstance(
+        s3_excinfo.value, CheckpointException
+    ), "Using S3StorageWriter DCP should raise a CheckpointException"
+
+    print("Test passed: Raised CheckpointException.")
+
+
+def test_dcp_load_non_existing_s3_uri(checkpoint_directory):
+    region = checkpoint_directory.region
+    non_existing_s3_uri = "s3://non-existing-bucket/checkpoint"
+
+    with pytest.raises(CheckpointException) as s3_excinfo:
+        dcp_load(
+            {},
+            S3StorageReader(
+                region,
+                non_existing_s3_uri,
+            ),
+        )
 
-def test_fsdp_filesystem_when_single_thread(checkpoint_directory):
-    # TODO: implement me
-    pass
+    assert isinstance(
+        s3_excinfo.value, CheckpointException
+    ), "Using S3StorageReader DCP should raise a CheckpointException"
 
+    print("Test passed: Raised CheckpointException.")
 
-def test_fsdp_filesystem_when_multiple_threads(checkpoint_directory):
-    # TODO: implement me
-    pass
+
+def test_successful_rename(checkpoint_directory):
+    src_path = f"{checkpoint_directory.s3_uri}test_rename_src"
+    test_data = {
+        "tensor1": torch.randn(10, 10),
+        "tensor2": torch.randn(5, 5),
+        "scalar": torch.tensor(3.14),
+    }
+    region = checkpoint_directory.region
+
+    # Test S3StorageWriter
+    s3_writer = S3StorageWriter(region, src_path, overwrite=False)
+    dcp_save(test_data, s3_writer)
+    s3_writer.fs.rename(f"{src_path}/.metadata", f"{src_path}/.metadata2")
+
+    assert not s3_writer.fs.exists(f"{src_path}/.metadata")
+    assert s3_writer.fs.exists(f"{src_path}/.metadata2")
+
+    print("Test passed: Rename was successful.")
+
+
+def test_rename_non_existing_s3_uri(checkpoint_directory):
+    region = checkpoint_directory.region
+    non_existing_s3_uri = f"{checkpoint_directory.s3_uri}non-existing-object"
+    storage_writer = S3StorageWriter(region, non_existing_s3_uri, overwrite=True)
+
+    with pytest.raises(Exception, match="Service error: The object was not found"):
+        storage_writer.fs.rename(
+            f"{non_existing_s3_uri}/.metadata", f"{non_existing_s3_uri}/.metadata2"
+        )
+
+    print("Test passed: Raised object not found error.")
+
+
+def test_rm_file_non_existing_s3_uri(checkpoint_directory):
+    region = checkpoint_directory.region
+    non_existing_s3_uri = f"{checkpoint_directory.s3_uri}non-existing-object-hooo"
+    storage_writer = S3StorageWriter(region, non_existing_s3_uri, overwrite=True)
+    storage_writer.fs.rm_file(non_existing_s3_uri)
+
+    print(
+        "Test passed: In case of delete did not throw error if the object was not found."
+    )
 
 
 # Inspired from https://github.com/pytorch/pytorch/blob/main/test/distributed/checkpoint/test_fsspec.py.
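A note on the repeated "Fixed issue with port number" entries in the commit message: setup() builds MASTER_PORT from random.randint(1, 500) plus a per-parametrization offset (20000-50000), so the four test cases draw from disjoint port ranges but can still collide within a range across concurrent runs. A common alternative, sketched here for comparison only (it is not what this commit does), is to bind port 0 and let the OS hand out a free ephemeral port:

    import socket


    def find_free_port() -> int:
        # Binding port 0 makes the OS assign an unused ephemeral port.
        # Closing the socket frees it for reuse as MASTER_PORT; a small
        # race remains between closing here and rebinding in the test.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("localhost", 0))
            return s.getsockname()[1]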
