1
- import re
1
+ import asyncio
2
2
import time
3
- from urllib . parse import urlparse
3
+ from pathlib import Path
4
4
5
- import fsspec
5
+ import obstore as obs
6
6
7
- from cubed .utils import join_path
7
+
8
def path_to_store(path):
    """Build an obstore store for *path*.

    Parameters
    ----------
    path : str or pathlib.Path
        Either a URL (anything containing ``"://"``), a plain local
        filesystem path string, or a ``Path`` object. Local paths are
        converted to ``file://`` URIs and the directory is created
        (``mkdir=True``) so the store is immediately writable.

    Returns
    -------
    An ``obstore`` store object suitable for ``obs.get`` / ``obs.put``.

    Raises
    ------
    TypeError
        If *path* is neither ``str`` nor ``Path``. (The previous
        implementation silently fell through and returned ``None``,
        which made downstream ``obs.get(None, ...)`` failures hard to
        trace back to the bad argument.)
    """
    if isinstance(path, str):
        if "://" in path:
            # Already a URL (s3://, gs://, file://, ...): pass through as-is.
            return obs.store.from_url(path)
        # Bare local path: normalise to a file:// URI and create the dir.
        return obs.store.from_url(Path(path).as_uri(), mkdir=True)
    elif isinstance(path, Path):
        return obs.store.from_url(path.as_uri(), mkdir=True)
    raise TypeError(f"path must be str or Path, got {type(path).__name__}")
8
16
9
17
10
- def read_int_from_file (path ):
11
- with fsspec . open ( path ) as f :
12
- return int (f . read ())
18
def read_int_from_file(store, path):
    """Fetch the object at *path* from *store* and parse it as an int.

    Raises ``FileNotFoundError`` (via obstore) when the key is absent.
    """
    response = obs.get(store, path)
    raw = response.bytes()
    return int(raw)
13
21
14
22
15
- def write_int_to_file (path , i ):
16
- with fsspec .open (path , "w" ) as f :
17
- f .write (str (i ))
23
def write_int_to_file(store, path, i):
    """Store the decimal string form of integer *i* at *path* in *store*."""
    payload = bytes(str(i), encoding="UTF8")
    obs.put(store, path, payload)
18
25
19
26
20
27
def deterministic_failure (path , timing_map , i , * , default_sleep = 0.01 , name = None ):
@@ -34,13 +41,12 @@ def deterministic_failure(path, timing_map, i, *, default_sleep=0.01, name=None)
34
41
they will all run normally.
35
42
"""
36
43
# increment number of invocations of this function with arg i
37
- invocation_count_file = join_path (path , f"{ i } " )
38
- fs = fsspec .open (invocation_count_file ).fs
39
- if fs .exists (invocation_count_file ):
40
- invocation_count = read_int_from_file (invocation_count_file )
41
- else :
44
+ store = path_to_store (path )
45
+ try :
46
+ invocation_count = read_int_from_file (store , f"{ i } " )
47
+ except FileNotFoundError :
42
48
invocation_count = 0
43
- write_int_to_file (invocation_count_file , invocation_count + 1 )
49
+ write_int_to_file (store , f" { i } " , invocation_count + 1 )
44
50
45
51
timing_code = default_sleep
46
52
if i in timing_map :
@@ -62,6 +68,20 @@ def deterministic_failure(path, timing_map, i, *, default_sleep=0.01, name=None)
62
68
63
69
def check_invocation_counts(
    path, timing_map, n_tasks, retries=None, expected_invocation_counts_overrides=None
):
    """Synchronous entry point: run the async invocation-count check to completion.

    Builds the coroutine and drives it with ``asyncio.run``; see
    ``check_invocation_counts_async`` for the parameter semantics.
    """
    coro = check_invocation_counts_async(
        path,
        timing_map,
        n_tasks,
        retries=retries,
        expected_invocation_counts_overrides=expected_invocation_counts_overrides,
    )
    asyncio.run(coro)
81
+
82
+
83
+ async def check_invocation_counts_async (
84
+ path , timing_map , n_tasks , retries = None , expected_invocation_counts_overrides = None
65
85
):
66
86
expected_invocation_counts = {}
67
87
for i in range (n_tasks ):
@@ -84,16 +104,11 @@ def check_invocation_counts(
84
104
expected_invocation_counts .update (expected_invocation_counts_overrides )
85
105
86
106
# retrieve outputs concurrently, so we can test on large numbers of inputs
87
- # see https://filesystem-spec.readthedocs.io/en/latest/async.html#synchronous-api
88
- if re .match (r"^[a-zA-Z]:\\" , str (path )): # Windows local file
89
- protocol = ""
90
- else :
91
- protocol = urlparse (str (path )).scheme
92
- fs = fsspec .filesystem (protocol )
93
- paths = [join_path (path , str (i )) for i in range (n_tasks )]
94
- out = fs .cat (paths )
95
- path_to_i = lambda p : int (p .rsplit ("/" , 1 )[- 1 ])
96
- actual_invocation_counts = {path_to_i (path ): int (val ) for path , val in out .items ()}
107
+ store = path_to_store (path )
108
+ paths = [str (i ) for i in range (n_tasks )]
109
+ results = await asyncio .gather (* [obs .get_async (store , path ) for path in paths ])
110
+ values = await asyncio .gather (* [result .bytes_async () for result in results ])
111
+ actual_invocation_counts = {i : int (val ) for i , val in enumerate (values )}
97
112
98
113
if actual_invocation_counts != expected_invocation_counts :
99
114
for i , expected_count in expected_invocation_counts .items ():
0 commit comments