44import copy
55from typing_extensions import Generator
66import pandas as pd
7+ import itertools
78
8- from dask .distributed import (
9- as_completed ,
10- futures_of ,
11- CancelledError ,
12- fire_and_forget ,
13- )
9+ from dask .distributed import CancelledError
1410from distributed .client import _get_global_client as get_client
1511
16- from dask .delayed import Delayed , delayed
17- import dask .array as da
12+ from dask .delayed import delayed
1813import dask .bag as db
1914from dask .diagnostics import ProgressBar
20- from dask import persist , compute
15+ from dask import compute
2116
2217from .. import Extents , Storage , Data , ShatterConfig
2318from ..resources .taskgraph import Graph
@@ -39,14 +34,15 @@ def get_data(extents: Extents, filename: str, storage: Storage):
3934 data .execute ()
4035
4136 points = p .get_dataframe (0 )
42- points = points .loc [points .Y < extents .bounds .maxy ]
43- points = points .loc [points .Y >= extents .bounds .miny ]
44- points = points .loc [points .X >= extents .bounds .minx ]
45- points = points .loc [points .X < extents .bounds .maxx , [* attrs , 'xi' , 'yi' ]]
37+ points = (points
38+ .loc [points .Y < extents .bounds .maxy ]
39+ .loc [points .Y >= extents .bounds .miny ]
40+ .loc [points .X >= extents .bounds .minx ]
41+ .loc [points .X < extents .bounds .maxx , [* attrs , 'xi' , 'yi' ]])
4642
47- points .loc [:, 'xi' ] = da .floor (points .xi )
43+ points .loc [:, 'xi' ] = np .floor (points .xi )
4844 # ceil for y because origin is at top left
49- points .loc [:, 'yi' ] = da .ceil (points .yi )
45+ points .loc [:, 'yi' ] = np .ceil (points .yi )
5046 return points
5147
5248
@@ -74,10 +70,10 @@ def agg_list(data_in, proc_num):
7470
7571 coerced = data_in .astype (col_dtypes | xyi_dtypes )
7672 gb = coerced .groupby (['xi' , 'yi' ], sort = False )
77- listed = gb .agg (lambda x : np .array (x , old_dtypes [x .name ]))
7873 counts_df = gb [first_col_name ].agg ('count' ).rename ('count' )
79- listed = listed .join (counts_df )
80- listed = listed .assign (shatter_process_num = proc_num )
74+ listed = (gb .agg (lambda x : np .array (x , old_dtypes [x .name ]))
75+ .join (counts_df )
76+ .assign (shatter_process_num = proc_num ))
8177
8278 return listed
8379
@@ -159,45 +155,49 @@ def kill_gracefully(signum, frame):
159155
160156 signal .signal (signal .SIGINT , kill_gracefully )
161157
162- # leaf_bag: db.Bag = db.from_sequence(leaves)
163- # processes = leaf_bag.map(do_one, config, storage)
164- processes = [delayed (do_one )(leaf , config , storage ) for leaf in leaves ]
165-
166158 ## If dask is distributed, use the futures feature
167159 dc = get_client ()
168160 consolidate_count = 30
169- count = 0
170161 if dc is not None :
171- pc_futures = futures_of (persist (processes ))
172- for batch in as_completed (pc_futures , with_results = True ).batches ():
173- for _ , pack in batch :
174- if isinstance (pack , CancelledError ):
175- continue
176- if isinstance (pack , int ):
177- pack = [pack ]
178- for pc in pack :
179- if isinstance (pc , BaseException ):
180- config .log .warning ('Worker returned exception: ' , pc )
181- if isinstance (pc , int ):
182- count += 1
183- if count >= consolidate_count :
184- faf = dc .submit (
185- storage .consolidate_shatter ,
186- timestamp = config .timestamp ,
187- )
188- fire_and_forget (faf )
189- count = 0
190- config .point_count = config .point_count + pc
191- del pc
162+ processes = []
163+ count = 0
164+ for leaf_bunch in itertools .batched (leaves , consolidate_count ):
165+ count = count + 1
166+ processes .append (dc .map (do_one , leaf_bunch , config = config , storage = storage ))
167+
168+ processes .append (dc .submit (storage .consolidate_shatter , config .timestamp ))
169+ gathered = dc .gather (processes )
170+ point_count = 0
171+ for pc in gathered :
172+ if pc is None :
173+ continue
174+ if isinstance (pc , int ):
175+ point_count = point_count + pc
176+ elif isinstance (pc , BaseException ):
177+ config .log .warning (pc )
178+ elif isinstance (pc , CancelledError ):
179+ config .log .warning (pc )
180+ del pc
192181
193- end_time = datetime .datetime .now ().timestamp () * 1000
194- config .end_time = end_time
195- config .finished = True
196- point_count = config .point_count
197182 else :
198183 # Handle non-distributed dask scenarios
199184 with ProgressBar ():
200- point_count = sum (* compute (processes ))
185+ count = 0
186+ futures = []
187+ for leaf in leaves :
188+ count = count + 1
189+ futures .append (delayed (do_one )(leaf , config , storage ))
190+ if count % consolidate_count == 0 :
191+ futures .append (delayed (storage .consolidate_shatter )(timestamp = config .timestamp ))
192+
193+ results = compute (* futures )
194+ pcs = [possible_pc for possible_pc in results if possible_pc is not None ]
195+ point_count = sum (pcs )
196+
197+ end_time = datetime .datetime .now ().timestamp () * 1000
198+ config .end_time = end_time
199+ config .finished = True
200+ config .point_count = point_count
201201
202202 return point_count
203203
@@ -234,9 +234,9 @@ def shatter(config: ShatterConfig) -> int:
234234 if config .tile_size is not None :
235235 leaves = extents .get_leaf_children (config .tile_size )
236236 else :
237- chunks = extents .chunk (data , pc_threshold = 600000 )
238- leaves = db .from_sequence (chunks ).compute ()
237+ leaves = extents .chunk (data )
239238
239+ leaves = itertools .chain (leaves )
240240 # Begin main operations
241241 config .log .debug ('Fetching and arranging data...' )
242242 storage .save_shatter_meta (config )