Skip to content

Commit afbae75

Browse files
authored
Run Modal tests on S3 using obstore (#794)
1 parent d2c34a5 commit afbae75

File tree

8 files changed

+62
-39
lines changed

8 files changed

+62
-39
lines changed

.github/workflows/modal-tests.yml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,13 +34,14 @@ jobs:
3434
- name: Install
3535
run: |
3636
python -m pip install --upgrade pip
37-
python -m pip install -e '.[test,modal]'
37+
python -m pip install -e '.[test]' modal obstore
3838
3939
- name: Run tests
4040
run: |
4141
pytest -vs -k "test_modal.py or modal" --runcloud
4242
env:
4343
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
4444
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
45+
AWS_REGION: us-east-1
4546
MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
4647
MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}

cubed/runtime/executors/modal.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -48,13 +48,12 @@
4848
[
4949
"array-api-compat",
5050
"donfig",
51-
"fsspec",
5251
"mypy_extensions", # for rechunker
5352
"ndindex",
5453
"networkx",
54+
"obstore",
5555
"psutil",
5656
"pytest-mock", # TODO: only needed for tests
57-
"s3fs",
5857
"tenacity",
5958
"toolz",
6059
"zarr",

cubed/tests/runtime/test_modal.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -25,13 +25,12 @@
2525
[
2626
"array-api-compat",
2727
"donfig",
28-
"fsspec",
2928
"mypy_extensions", # for rechunker
3029
"ndindex",
3130
"networkx",
31+
"obstore",
3232
"psutil",
3333
"pytest-mock", # TODO: only needed for tests
34-
"s3fs",
3534
"tenacity",
3635
"toolz",
3736
"zarr",
@@ -41,7 +40,7 @@
4140

4241
@app.function(
4342
image=image,
44-
secrets=[modal.Secret.from_name("my-aws-secret")],
43+
secrets=[modal.Secret.from_name("aws-secret-us-east-1")],
4544
retries=2,
4645
timeout=10,
4746
cloud="aws",
@@ -53,7 +52,7 @@ def deterministic_failure_modal(i, path=None, timing_map=None, *, name=None):
5352

5453
@app.function(
5554
image=image,
56-
secrets=[modal.Secret.from_name("my-aws-secret")],
55+
secrets=[modal.Secret.from_name("aws-secret-us-east-1")],
5756
timeout=10,
5857
cloud="aws",
5958
region=region,
@@ -64,7 +63,7 @@ def deterministic_failure_modal_no_retries(i, path=None, timing_map=None, *, nam
6463

6564
@app.function(
6665
image=image,
67-
secrets=[modal.Secret.from_name("my-aws-secret")],
66+
secrets=[modal.Secret.from_name("aws-secret-us-east-1")],
6867
retries=2,
6968
timeout=300,
7069
cloud="aws",

cubed/tests/runtime/utils.py

Lines changed: 41 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,27 @@
1-
import re
1+
import asyncio
22
import time
3-
from urllib.parse import urlparse
3+
from pathlib import Path
44

5-
import fsspec
5+
import obstore as obs
66

7-
from cubed.utils import join_path
7+
8+
def path_to_store(path):
9+
if isinstance(path, str):
10+
if "://" not in path:
11+
return obs.store.from_url(Path(path).as_uri(), mkdir=True)
12+
else:
13+
return obs.store.from_url(path)
14+
elif isinstance(path, Path):
15+
return obs.store.from_url(path.as_uri(), mkdir=True)
816

917

10-
def read_int_from_file(path):
11-
with fsspec.open(path) as f:
12-
return int(f.read())
18+
def read_int_from_file(store, path):
19+
result = obs.get(store, path)
20+
return int(result.bytes())
1321

1422

15-
def write_int_to_file(path, i):
16-
with fsspec.open(path, "w") as f:
17-
f.write(str(i))
23+
def write_int_to_file(store, path, i):
24+
obs.put(store, path, bytes(str(i), encoding="UTF8"))
1825

1926

2027
def deterministic_failure(path, timing_map, i, *, default_sleep=0.01, name=None):
@@ -34,13 +41,12 @@ def deterministic_failure(path, timing_map, i, *, default_sleep=0.01, name=None)
3441
they will all run normally.
3542
"""
3643
# increment number of invocations of this function with arg i
37-
invocation_count_file = join_path(path, f"{i}")
38-
fs = fsspec.open(invocation_count_file).fs
39-
if fs.exists(invocation_count_file):
40-
invocation_count = read_int_from_file(invocation_count_file)
41-
else:
44+
store = path_to_store(path)
45+
try:
46+
invocation_count = read_int_from_file(store, f"{i}")
47+
except FileNotFoundError:
4248
invocation_count = 0
43-
write_int_to_file(invocation_count_file, invocation_count + 1)
49+
write_int_to_file(store, f"{i}", invocation_count + 1)
4450

4551
timing_code = default_sleep
4652
if i in timing_map:
@@ -62,6 +68,20 @@ def deterministic_failure(path, timing_map, i, *, default_sleep=0.01, name=None)
6268

6369
def check_invocation_counts(
6470
path, timing_map, n_tasks, retries=None, expected_invocation_counts_overrides=None
71+
):
72+
asyncio.run(
73+
check_invocation_counts_async(
74+
path,
75+
timing_map,
76+
n_tasks,
77+
retries=retries,
78+
expected_invocation_counts_overrides=expected_invocation_counts_overrides,
79+
)
80+
)
81+
82+
83+
async def check_invocation_counts_async(
84+
path, timing_map, n_tasks, retries=None, expected_invocation_counts_overrides=None
6585
):
6686
expected_invocation_counts = {}
6787
for i in range(n_tasks):
@@ -84,16 +104,11 @@ def check_invocation_counts(
84104
expected_invocation_counts.update(expected_invocation_counts_overrides)
85105

86106
# retrieve outputs concurrently, so we can test on large numbers of inputs
87-
# see https://filesystem-spec.readthedocs.io/en/latest/async.html#synchronous-api
88-
if re.match(r"^[a-zA-Z]:\\", str(path)): # Windows local file
89-
protocol = ""
90-
else:
91-
protocol = urlparse(str(path)).scheme
92-
fs = fsspec.filesystem(protocol)
93-
paths = [join_path(path, str(i)) for i in range(n_tasks)]
94-
out = fs.cat(paths)
95-
path_to_i = lambda p: int(p.rsplit("/", 1)[-1])
96-
actual_invocation_counts = {path_to_i(path): int(val) for path, val in out.items()}
107+
store = path_to_store(path)
108+
paths = [str(i) for i in range(n_tasks)]
109+
results = await asyncio.gather(*[obs.get_async(store, path) for path in paths])
110+
values = await asyncio.gather(*[result.bytes_async() for result in results])
111+
actual_invocation_counts = {i: int(val) for i, val in enumerate(values)}
97112

98113
if actual_invocation_counts != expected_invocation_counts:
99114
for i, expected_count in expected_invocation_counts.items():

cubed/tests/test_array_api.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -452,7 +452,7 @@ def test_matmul_cloud(executor):
452452
@pytest.mark.cloud
453453
def test_matmul_modal(modal_executor):
454454
tmp_path = "s3://cubed-unittest/matmul"
455-
spec = cubed.Spec(tmp_path, allowed_mem=100000)
455+
spec = cubed.Spec(tmp_path, allowed_mem=100000, storage_options=dict(use_obstore=True))
456456

457457
a = xp.asarray(
458458
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],

cubed/tests/test_executor_features.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -150,7 +150,9 @@ def test_rich_progress_bar(spec, executor):
150150
def test_callbacks_modal(spec, modal_executor):
151151
task_counter = TaskCounter(check_timestamps=False)
152152
tmp_path = "s3://cubed-unittest/callbacks"
153-
spec = cubed.Spec(tmp_path, allowed_mem=100000)
153+
spec = cubed.Spec(
154+
tmp_path, allowed_mem=100000, storage_options=dict(use_obstore=True)
155+
)
154156

155157
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
156158
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
@@ -241,7 +243,9 @@ def test_compute_arrays_in_parallel(spec, any_executor, compute_arrays_in_parall
241243
@pytest.mark.parametrize("compute_arrays_in_parallel", [True, False])
242244
def test_compute_arrays_in_parallel_modal(modal_executor, compute_arrays_in_parallel):
243245
tmp_path = "s3://cubed-unittest/parallel_pipelines"
244-
spec = cubed.Spec(tmp_path, allowed_mem=100000)
246+
spec = cubed.Spec(
247+
tmp_path, allowed_mem=100000, storage_options=dict(use_obstore=True)
248+
)
245249

246250
a = cubed.random.random((10, 10), chunks=(5, 5), spec=spec)
247251
b = cubed.random.random((10, 10), chunks=(5, 5), spec=spec)
@@ -290,7 +294,11 @@ def test_check_runtime_memory_dask_no_workers(spec, executor):
290294
@pytest.mark.cloud
291295
def test_check_runtime_memory_modal(spec, modal_executor):
292296
tmp_path = "s3://cubed-unittest/check-runtime-memory"
293-
spec = cubed.Spec(tmp_path, allowed_mem="4GB") # larger than Modal runtime memory
297+
spec = cubed.Spec(
298+
tmp_path,
299+
allowed_mem="4GB", # larger than Modal runtime memory
300+
storage_options=dict(use_obstore=True),
301+
)
294302
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
295303
b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec)
296304
c = xp.add(a, b)

cubed/tests/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,7 @@
7878
{
7979
"spec.executor_options.cloud": "aws",
8080
"spec.executor_options.region": "us-east-1",
81+
"spec.executor_options.secret": "aws-secret-us-east-1",
8182
}
8283
)
8384
executor_options = dict(enable_output=True)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,8 +82,8 @@ coiled = [
8282
test = [
8383
"cubed[diagnostics]",
8484
"dill",
85-
"fsspec",
8685
"numpy_groupies",
86+
"obstore",
8787
"pytest",
8888
"pytest-cov",
8989
"pytest-mock",

0 commit comments

Comments
 (0)