 from pathlib import Path
-from itertools import chain
-
 
+from dask.diagnostics import ProgressBar
+from distributed.client import _get_global_client as get_client
 from typing_extensions import Union
 from osgeo import gdal, osr
 import dask
 import numpy as np
 import pandas as pd
 
 
-from .. import Storage, Extents, ExtractConfig
+from .. import Storage, Extents, ExtractConfig, Bounds, Graph
 
 np_to_gdal_types = {
     np.dtype(np.byte).str: gdal.GDT_Byte,
 ...
 
 
 def write_tif(
-    xsize: int,
-    ysize: int,
+    bounds: Bounds,
     data: np.ndarray,
     nan_val: float | int,
     name: str,
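
For reference, np_to_gdal_types maps numpy dtype strings to GDAL band types. A minimal sketch of the lookup it presumably backs (the actual call site is outside this hunk; gdal_type_for is a hypothetical helper):

import numpy as np
from osgeo import gdal

np_to_gdal_types = {np.dtype(np.byte).str: gdal.GDT_Byte}  # abbreviated

def gdal_type_for(dtype) -> int:
    # np.dtype(...).str yields keys like '|i1', matching the table above
    return np_to_gdal_types[np.dtype(dtype).str]
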
@@ -48,13 +47,14 @@ def write_tif(
     crs = config.crs
     srs = osr.SpatialReference()
     srs.ImportFromWkt(crs.to_wkt())
-    b = config.bounds
+    minx, miny, maxx, maxy = bounds.get()
+    ysize, xsize = data.shape
 
     transform = [
-        b.minx,
+        minx,
         config.resolution,
         0,
-        b.maxy,
+        maxy,
         0,
         -1 * config.resolution,
     ]
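
GDAL's six-element geotransform maps pixel (column, row) to georeferenced (x, y); with a negative row step the raster is north-up, which is why the transform above anchors at maxy. A standalone sketch with illustrative numbers (not from this commit):

minx, maxy, resolution = 500000.0, 4100000.0, 30.0
transform = [minx, resolution, 0, maxy, 0, -1 * resolution]

def pixel_to_coord(col, row):
    # x grows with columns, y shrinks with rows (north-up raster)
    x = transform[0] + col * transform[1] + row * transform[2]
    y = transform[3] + col * transform[4] + row * transform[5]
    return x, y

assert pixel_to_coord(0, 0) == (500000.0, 4100000.0)  # top-left corner
assert pixel_to_coord(1, 1) == (500030.0, 4099970.0)
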
@@ -70,8 +70,8 @@ def write_tif(
     )
     tif.SetGeoTransform(transform)
     tif.SetProjection(srs.ExportToWkt())
-    tif.GetRasterBand(1).WriteArray(data)
     tif.GetRasterBand(1).SetNoDataValue(nan_val)
+    tif.GetRasterBand(1).WriteArray(data)
     tif.FlushCache()
     tif = None
 
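
Setting the nodata value before writing the array means readers never see the band without its fill value declared. A self-contained sketch of the same write path (path and sizes are illustrative only):

import numpy as np
from osgeo import gdal

driver = gdal.GetDriverByName('GTiff')
tif = driver.Create('/tmp/example.tif', 4, 3, 1, gdal.GDT_Float64)
tif.SetGeoTransform([0.0, 1.0, 0.0, 3.0, 0.0, -1.0])
band = tif.GetRasterBand(1)
band.SetNoDataValue(-9999)
band.WriteArray(np.full((3, 4), -9999.0))
tif.FlushCache()
tif = None  # dropping the reference closes and flushes the dataset
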
@@ -98,17 +98,20 @@ def expl(x):
 
     # set the index so we can apply over the whole dataset without needing to
     # skip X and Y, then reset the index, because that's what metric.do expects
-    exploded = data_in.set_index(['X', 'Y']).apply(expl)[attrs].reset_index()
-    metric_data = dask.persist(
-        *[m.do(exploded) for m in storage.config.metrics]
-    )
+    data_in = data_in.set_index(['Y', 'X'])
+    exploded = data_in.apply(expl)[attrs].reset_index()
 
-    data_out = data_in.set_index(['X', 'Y']).join([m for m in metric_data])
-    return data_out
+    exploded.rename(columns={'X': 'xi', 'Y': 'yi'}, inplace=True)
+    graph = Graph(storage.config.metrics)
+    metric_data = graph.run(exploded)
+    # rename index from xi,yi to X,Y
+    metric_data.index = metric_data.index.rename(['Y', 'X'])
+
+    return metric_data
 
 
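
A tiny standalone pandas sketch (dummy cells and metric name; Graph itself is the library's) of the index round-trip get_metrics now performs, entering keyed by ('yi', 'xi') and leaving keyed by ('Y', 'X'):

import pandas as pd

cells = pd.DataFrame({'yi': [0, 0, 1], 'xi': [0, 1, 0], 'm_Z_mean': [1.0, 2.0, 3.0]})
metric_data = cells.set_index(['yi', 'xi'])
# rename the index levels without touching the values
metric_data.index = metric_data.index.rename(['Y', 'X'])
assert metric_data.index.names == ['Y', 'X']
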
 def handle_overlaps(
-    config: ExtractConfig, storage: Storage, indices: np.ndarray
+    config: ExtractConfig, storage: Storage, extents: Extents
 ) -> pd.DataFrame:
     """
     Handle cells that have overlapping data. We have to re-run metrics over
@@ -124,10 +127,10 @@ def handle_overlaps(
     ma_list = storage.getDerivedNames()
     att_list = [a.name for a in config.attrs]
 
-    minx = indices['x'].min()
-    maxx = indices['x'].max()
-    miny = indices['y'].min()
-    maxy = indices['y'].max()
+    minx = extents.x1
+    maxx = extents.x2
+    miny = extents.y1
+    maxy = extents.y2
 
     att_meta = {}
     att_meta['X'] = np.int32
@@ -137,44 +140,43 @@ def handle_overlaps(
         att_meta[a.name] = a.dtype
 
     with storage.open('r') as tdb:
-        # TODO this can be more efficient. Use count to find indices, then work
-        # with that smaller set from there. Working as is for now, but slow.
-        dit = tdb.query(
-            attrs=[*att_list, *ma_list],
+        storage.config.log.info('Looking for overlaps...')
+        data = tdb.query(
+            attrs=[*ma_list],
             order='F',
             coords=True,
-            return_incomplete=True,
-            use_arrow=False,
         ).df[minx:maxx, miny:maxy]
-        data = pd.DataFrame()
-
-        storage.config.log.info('Collecting database information...')
-        for d in dit:
-            if data.empty:
-                data = d
-            else:
-                data = pd.concat([data, d])
+        data = data
 
         # find values that are not unique, meaning they have multiple entries
-        data = data.set_index(['X', 'Y'])
+        data = data.set_index(['Y', 'X'])
         redo_indices = data.index[data.index.duplicated(keep='first')]
         if redo_indices.empty:
-            return data.reset_index()
+            storage.config.log.info('No overlapping data. Continuing...')
+            return data
 
-        # data with overlaps
         redo_data = (
-            data.loc[redo_indices][att_list]
-            .groupby(['X', 'Y'])
-            .agg(lambda x: list(chain(*x)))
+            tdb.query(
+                attrs=[*att_list],
+                order='F',
+                coords=True,
+                use_arrow=False,
+            )
+            .df[:, :]
+            .set_index(['Y', 'X'])
         )
+
+        # data with overlaps
+        redo_data = redo_data.loc[redo_indices]
+
         # data that has no overlaps
         clean_data = data.loc[data.index[~data.index.duplicated(False)]]
 
         storage.config.log.warning(
             'Overlapping data detected. Rerunning metrics over these cells...'
         )
         new_metrics = get_metrics(redo_data.reset_index(), storage)
-        return pd.concat([clean_data, new_metrics]).reset_index()
+        return pd.concat([clean_data, new_metrics])
 
 
 def extract(config: ExtractConfig) -> None:
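
The overlap detection above hinges on pandas' Index.duplicated. A small standalone sketch (dummy values) of how the cells split into redo versus clean:

import pandas as pd

data = pd.DataFrame(
    {'Y': [0, 0, 1], 'X': [0, 0, 1], 'v': [1.0, 2.0, 3.0]}
).set_index(['Y', 'X'])

# keep='first' marks every repeat after the first occurrence;
# duplicated(False) marks all members of a duplicated group
redo_indices = data.index[data.index.duplicated(keep='first')]   # [(0, 0)]
clean_data = data.loc[data.index[~data.index.duplicated(False)]]  # only (1, 1)
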
@@ -197,28 +199,43 @@ def extract(config: ExtractConfig) -> None:
         storage.config.alignment,
         root=root_bounds,
     )
-    i = e.get_indices()
-    xsize = e.x2
-    ysize = e.y2
 
     # figure out if there are any overlaps and handle them
-    final = handle_overlaps(config, storage, i).sort_values(['Y', 'X'])
+    final = handle_overlaps(config, storage, e)
+
+    xis = final.index.get_level_values(1).astype(np.int64)
+    yis = final.index.get_level_values(0).astype(np.int64)
+    new_idx = pd.MultiIndex.from_product(
+        (range(yis.min(), yis.max() + 1), range(xis.min(), xis.max() + 1))
+    ).rename(['Y', 'X'])
+    final = final.reindex(new_idx)
+
+    xs = root_bounds.minx + xis * config.resolution
+    ys = root_bounds.maxy - yis * config.resolution
+    final_bounds = Bounds(xs.min(), ys.min(), xs.max(), ys.max())
 
     # output metric data to tifs
     config.log.info(f'Writing rasters to {config.out_dir}')
+    futures = []
     for ma in ma_list:
         # TODO should output in sections so we don't run into memory problems
         dtype = final[ma].dtype
-        if dtype.kind in ['u', 'i']:
-            nan_val = np.iinfo(dtype).max
-        elif dtype.kind == 'f':
-            nan_val = np.nan
+        if dtype.kind == 'u':
+            nan_val = 0
+        elif dtype.kind in ['i', 'f']:
+            nan_val = -9999
         else:
-            raise ValueError('Invalid Raster data type {dtype}.')
+            nan_val = 0
+        unstacked = final[ma].unstack()
+        m_data = unstacked.to_numpy()
 
-        m_data = np.full(shape=(ysize, xsize), fill_value=nan_val, dtype=dtype)
-        a = final[['X', 'Y', ma]].to_numpy()
-        for x, y, md in a[:]:
-            m_data[int(y)][int(x)] = md
+        futures.append(
+            dask.delayed(write_tif)(final_bounds, m_data, nan_val, ma, config)
+        )
 
-        write_tif(xsize, ysize, m_data, nan_val, ma, config)
+    dc = get_client()
+    if dc is not None:
+        dask.compute(*futures)
+    else:
+        with ProgressBar():
+            dask.compute(*futures)
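
The raster assembly now densifies the sparse (Y, X) results and unstacks them instead of looping cell by cell, and the writes run as dask tasks, through the distributed client when one exists and under a local progress bar otherwise. A condensed standalone sketch (dummy values, the write stubbed with print):

import dask
import numpy as np
import pandas as pd
from dask.diagnostics import ProgressBar
from distributed.client import _get_global_client as get_client

# sparse cell results: cell (0, 1) is missing
final = pd.DataFrame(
    {'Y': [0, 1, 1], 'X': [0, 0, 1], 'm_Z_mean': [1.0, 2.0, 3.0]}
).set_index(['Y', 'X'])

# densify to the full grid so unstack() yields a complete 2D array
yis = final.index.get_level_values(0)
xis = final.index.get_level_values(1)
new_idx = pd.MultiIndex.from_product(
    (range(yis.min(), yis.max() + 1), range(xis.min(), xis.max() + 1))
).rename(['Y', 'X'])
final = final.reindex(new_idx)

m_data = final['m_Z_mean'].unstack().to_numpy()  # shape (ysize, xsize), NaN where empty

futures = [dask.delayed(print)(m_data.shape)]  # stand-in for the delayed write_tif calls
if get_client() is not None:
    dask.compute(*futures)
else:
    with ProgressBar():
        dask.compute(*futures)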