@@ -209,18 +209,31 @@ Fill the pyramids generated from the `data` with the aggregation function `func`
209209`recursive` indicates whether higher tiles are computed from lower tiles or directly from the original data.
210210This is an optimization which for functions like median might lead to misleading results.
211211"""
212- function fill_pyramids (data, outputs,func,recursive;runner= LocalRunner, kwargs... )
212+ function fill_pyramids (data, outputs,func,recursive,spatial_dims ;runner= LocalRunner, kwargs... )
213213 n_level = length (outputs)
214214 pixel_base_size = 2 ^ n_level
215215 pyramid_sizes = size .(outputs)
216216 tmp_sizes = [ceil (Int,pixel_base_size / 2 ^ i) for i in 1 : n_level]
217217
218- ia = InputArray (data, windows = arraywindows (size (data),pixel_base_size))
218+ inputwindows = Base. OneTo .(size (data))
219+ outputwindows = Base. OneTo .(size (data))
220+ for d in spatial_dims
221+ inputwindows = Base. setindex (inputwindows,RegularWindows (1 ,size (data,d),window= pixel_base_size),d)
222+ end
223+
224+ ia = InputArray (data, windows = inputwindows)
225+
219226
220- oa = ntuple (i-> create_outwindows (pyramid_sizes[i],windows = arraywindows (pyramid_sizes[i],tmp_sizes[i])),n_level)
227+ oa = ntuple (n_level) do i
228+ for d in spatial_dims
229+ outputwindows = Base. setindex (outputwindows,RegularWindows (1 ,pyramid_sizes[i][d],window= tmp_sizes[i]),d)
230+ end
231+ create_outwindows (pyramid_sizes[i],windows = outputwindows)
232+ end
221233
222234 func = DiskArrayEngine. create_userfunction (gen_pyr,ntuple (_-> eltype (first (outputs)),length (outputs));is_mutating= true ,kwargs = (;recursive),args = (func,))
223235
236+
224237 op = GMDWop ((ia,), oa, func)
225238
226239 lr = DiskArrayEngine. optimize_loopranges (op,5e8 ,tol_low= 0.2 ,tol_high= 0.05 ,max_order= 2 )
@@ -317,7 +330,7 @@ output_arrays(pyramid_sizes, T) = [gen_output(T,p) for p in pyramid_sizes]
317330 SpatialDim
318331Union of Dimensions which are assumed to be in space and are therefore used in the pyramid building.
319332"""
320- SpatialDim = Union{ DD. Dimensions. XDim, DD. Dimensions. YDim}
333+ SpatialDim = ( DD. Dimensions. XDim, DD. Dimensions. YDim)
321334
322335"""
323336 buildpyramids(path; resampling_method=mean)
@@ -326,7 +339,7 @@ The different scales are written according to the GeoZarr spec and a multiscales
326339The data is aggregated with the specified `resampling_method`.
327340Keyword arguments are forwarded to the `fill_pyramids` function.
328341"""
329- function buildpyramids (path:: AbstractString ; resampling_method= mean, recursive= true , runner= LocalRunner, verbose= false )
342+ function buildpyramids (path:: AbstractString ; resampling_method= mean, recursive= true , runner= LocalRunner, verbose= false , spatial_dims = SpatialDim )
330343 if YAB. backendfrompath (path) != YAB. backendlist[:zarr ]
331344 @show YAB. backendfrompath (path)
332345 throw (ArgumentError (" $path is not a Zarr dataset therefore we can't build the Pyramids inplace" ))
@@ -340,17 +353,19 @@ function buildpyramids(path::AbstractString; resampling_method=mean, recursive=t
340353 # t = Missing <: eltype(org) ? Union{Missing, tfunc} : tfunc
341354
342355 t = Base. infer_return_type (resampling_method, (Matrix{nonmissingtype (eltype (org))},))
343-
344356 n_level = compute_nlevels (org)
345- input_axes = filter (x-> x isa SpatialDim, DD. dims (org))
357+ input_axes = DD. dims (org, spatial_dims)
358+ outarrs = [output_zarr (n, DD. dims (org), t, joinpath (path, string (n)),input_axes) for n in 1 : n_level]
359+
346360 if length (input_axes) != 2
347361 throw (ArgumentError (" Expected two spatial dimensions got $input_axes " ))
348362 end
349363 verbose && println (" Constructing output arrays" )
350- outarrs = [output_zarr (n, input_axes, t, joinpath (path, string (n))) for n in 1 : n_level]
351364 verbose && println (" Start computation" )
352- fill_pyramids (org, outarrs, resampling_method, recursive;runner)
353- pyraxs = [agg_axis .(input_axes, 2 ^ n) for n in 1 : n_level]
365+ ispatial_dims = DD. dimnum (DD. dims (org),spatial_dims)
366+ fill_pyramids (org, outarrs, resampling_method, recursive, ispatial_dims;runner)
367+ pyraxs_space = [agg_axis .(input_axes, 2 ^ n) for n in 1 : n_level]
368+ pyraxs = [DD. setdims (DD. dims (org),p) for p in pyraxs_space]
354369 pyrlevels = DD. DimArray .(outarrs, pyraxs)
355370 meta = Dict (deepcopy (DD. metadata (org)))
356371 push! (meta, " resampling_method" => string (resampling_method))
@@ -370,15 +385,23 @@ end
370385
371386"""
372387 output_zarr(n, input_axes, t, path)
388+
373389Construct a Zarr dataset for the level n of a pyramid for the dimensions `input_axes`.
374390It sets the type to `t` and saves it to `path/n`
375391"""
376- function output_zarr (n, input_axes, t, path)
377- aggdims = agg_axis .(input_axes, 2 ^ n)
378- s = length .(aggdims)
392+ function output_zarr (n, input_axes, t, path, spatialdims;chunksizes= nothing )
393+ spatial_axes = DD. dims (input_axes,spatialdims)
394+ aggdims = agg_axis .(spatial_axes, 2 ^ n)
395+ spatial_axes_new = DD. setdims (input_axes,aggdims)
396+ s = length .(spatial_axes_new)
379397 z = Zeros (t, s... )
380- yax = YAXArray (aggdims, z)
381- chunked = setchunks (yax , (1024 , 1024 ))
398+ yax = YAXArray (spatial_axes_new, z)
399+ if chunksizes === nothing
400+ dchunksizes = map (d-> DD. rebuild (d,1 ),input_axes)
401+ dchunksizes = DD. setdims (dchunksizes,map (d-> DD. rebuild (d,1024 ),spatial_axes))
402+ chunksizes = map (i-> i. val,dchunksizes)
403+ end
404+ chunked = setchunks (yax , chunksizes)
382405 # This assumes that there is only the spatial dimensions to save
383406 ds = to_dataset (chunked, )
384407 dssaved = savedataset (ds; path, skeleton= true , driver= :zarr )
391414Compute the data of the pyramids of a given data cube `ras`.
392415This returns the data of the pyramids and the dimension values of the aggregated axes.
393416"""
394- function getpyramids (reducefunc, ras;recursive= true , tilesize= 256 )
417+ function getpyramids (reducefunc, ras;recursive= true , tilesize= 256 , spatial_dims = SpatialDim )
395418 input_axes = DD. dims (ras)
396419 n_level = compute_nlevels (ras, tilesize)
397420 if iszero (n_level)
@@ -403,7 +426,9 @@ function getpyramids(reducefunc, ras;recursive=true, tilesize=256)
403426 outtype = Base. infer_return_type (reducefunc, (Matrix{eltype (ras)},))
404427 # outtype = Missing <: eltype(ras) ? Union{Missing, outtypefunc} : outtypefunc
405428 outmin = output_arrays (pyramid_sizes, outtype)
406- fill_pyramids (ras,outmin,reducefunc,recursive; threaded= true )
429+ ispatial_dims = DD. dimnum (DD. dims (ras),spatial_dims)
430+
431+ fill_pyramids (ras,outmin,reducefunc,recursive, ispatial_dims; threaded= true )
407432
408433 outmin, pyramid_axes
409434end
0 commit comments