
Commit 16a2771

Merge pull request #92 from chrishavlin/reorganize_chunking_ops
Reorganize chunking operations
2 parents b6880a8 + 7061d97 commit 16a2771

3 files changed: +150 -32 lines


yt_xarray/accessor/accessor.py

Lines changed: 8 additions & 32 deletions
@@ -10,6 +10,7 @@
 from yt_xarray.accessor import _xr_to_yt
 from yt_xarray.accessor._readers import _get_xarray_reader
 from yt_xarray.accessor._xr_to_yt import _load_full_field_from_xr
+from yt_xarray.utilities._grid_decomposition import ChunkInfo
 from yt_xarray.utilities.logging import ytxr_log
 
 
@@ -535,18 +536,14 @@ def _load_chunked_grid(
     # otherwise it is number of nodes (which are treated as new cell centers).
     # the bbox will already account for this as well.
 
-    # do some grid/chunk counting
-    n_chnk = np.asarray(data_shp) / chunksizes  # may not be int
-    n_whl_chnk = np.floor(n_chnk).astype(int)  # whole chunks in each dim
-    n_part_chnk = np.ceil(n_chnk - n_whl_chnk).astype(int)  # partial chunks
-
-    n_tots = np.prod(n_part_chnk + n_whl_chnk)
-    ytxr_log.info(f"Constructing a yt chunked grid with {n_tots} chunks.")
-
     # initialize the global starting index
     si = np.array([0, 0, 0], dtype=int)
     si = sel_info.starting_indices + si
 
+    # do some grid/chunk counting
+    chnkinfo = ChunkInfo(data_shp, chunksizes, starting_index_offset=si)
+    ytxr_log.info(f"Constructing a yt chunked grid with {chnkinfo.n_tots} chunks.")
+
     # select field for grabbing coordinate arrays -- fields should all be
     # verified by now
     fld = fields[0]
@@ -564,29 +561,8 @@ def _load_chunked_grid(
     subgrid_start = []
     subgrid_end = []
     for idim in range(sel_info.ndims):
-        si_0 = si[idim] + chunksizes[idim] * np.arange(n_whl_chnk[idim])
-        ei_0 = si_0 + chunksizes[idim]
-
-        if n_part_chnk[idim] == 1:
-            si_0_partial = ei_0[-1]
-            ei_0_partial = data_shp[idim] - si_0_partial
-            si_0 = np.concatenate(
-                [
-                    si_0,
-                    [
-                        si_0_partial,
-                    ],
-                ]
-            )
-            ei_0 = np.concatenate(
-                [
-                    ei_0,
-                    [
-                        ei_0[-1] + ei_0_partial,
-                    ],
-                ]
-            )
-
+        si_0 = chnkinfo.si[idim]
+        ei_0 = chnkinfo.ei[idim]
         c = cnames[idim]
         rev_ax = sel_info.reverse_axis[idim]
         if rev_ax is False:
@@ -608,7 +584,7 @@ def _load_chunked_grid(
         le_0 = np.concatenate([[min_val], re_0[:-1]])
 
         # sizes also already account for interp_required
-        subgrid_size = ei_0 - si_0
+        subgrid_size = chnkinfo.sizes[idim]
 
         left_edges.append(le_0)
         right_edges.append(re_0)

yt_xarray/tests/test_chunking.py

Lines changed: 39 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 import yt_xarray  # noqa: F401
 from yt_xarray import sample_data
+from yt_xarray.utilities._grid_decomposition import ChunkInfo
 from yt_xarray.utilities._utilities import construct_minimal_ds
 
 
@@ -146,3 +147,41 @@ def test_chunk_bad_length():
 
     with pytest.raises(ValueError, match="The number of elements in "):
         _ = ds.yt.load_grid(length_unit="km", chunksizes=(30, 40, 20, 5))
+
+
+_chunk_tests = [
+    ((20, 30, 40), (10, 15, 20), (0,) * 3, (2, 2, 2)),
+    ((20, 30, 40), (15, 15, 20), (0,) * 3, (2, 2, 2)),
+    ((10, 15, 20), (5, 5, 5), None, (2, 3, 4)),
+    ((10, 15, 20), (5, 5, 5), (1, 2, 3), (2, 3, 4)),
+]
+
+
+@pytest.mark.parametrize("data_shape,chunksizes,si0, expected_nchunks", _chunk_tests)
+def test_chunk_info(data_shape, chunksizes, si0, expected_nchunks):
+    chunksizes = np.array(chunksizes, dtype="int")
+    if si0 is not None:
+        si0 = np.array(si0, dtype="int")
+    ch = ChunkInfo(data_shape, chunksizes, starting_index_offset=si0)
+    chunks = ch.n_whl_chnk + ch.n_part_chnk
+    assert np.all(chunks == np.asarray(expected_nchunks))
+    if si0 is not None:
+        si = np.array([ch.si[id][0] for id in range(3)])
+        assert np.all(si == si0)
+
+
+def test_chunk_info_caching():
+
+    chunksizes = np.array([5, 5, 5], dtype="int")
+    data_shape = (10, 15, 20)
+
+    def _get_ch():
+        return ChunkInfo(data_shape, chunksizes)
+
+    ch = _get_ch()
+    _ = ch.ei
+    ch = _get_ch()
+    _ = ch.sizes
+    assert ch._sizes is not None
+    assert ch._si is not None
+    assert ch._ei is not None
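The parametrized cases double as a usage reference for ChunkInfo. A sketch of the last case above (even division with a nonzero offset); the asserted values follow from the class's arithmetic rather than from any documented API:

import numpy as np

from yt_xarray.utilities._grid_decomposition import ChunkInfo

ch = ChunkInfo(
    (10, 15, 20), np.array([5, 5, 5]), starting_index_offset=np.array([1, 2, 3])
)

assert np.all(ch.n_whl_chnk + ch.n_part_chnk == [2, 3, 4])  # expected_nchunks
assert np.all(ch.si[0] == [1, 6])     # dim-0 chunk starts include the offset
assert np.all(ch.ei[0] == [6, 11])    # dim-0 chunk ends
assert np.all(ch.sizes[0] == [5, 5])  # both dim-0 chunks are whole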

yt_xarray/utilities/_grid_decomposition.py

Lines changed: 103 additions & 0 deletions
@@ -421,3 +421,106 @@ def _get_yt_ds(
         refine_by=refine_by,
         **load_kwargs,
     )
+
+
+class ChunkInfo:
+    """
+    Class for tracking info related to chunked-decomposition of a domain
+
+    Parameters
+    ----------
+    data_shp: Tuple[int,]
+        the global shape of the data to chunk
+    chunksizes: np.ndarray[int]
+        the chunksizes in each dimension of data_shp
+    starting_index_offset: np.ndarray[int]
+        global index offset. start and end indices will be offset
+        by this array. Defaults to [0,0,0].
+    """
+
+    def __init__(
+        self,
+        data_shp: Tuple[int,],
+        chunksizes: np.ndarray,
+        starting_index_offset: np.ndarray = None,
+    ):
+
+        self.chunksizes = chunksizes
+        self.data_shape = np.asarray(data_shp)
+        self.n_chnk = self.data_shape / chunksizes  # may not be int
+        self.n_whl_chnk = np.floor(self.n_chnk).astype(int)  # whole chunks in each dim
+        self.n_part_chnk = np.ceil(self.n_chnk - self.n_whl_chnk).astype(int)
+        self.n_tots = np.prod(self.n_part_chnk + self.n_whl_chnk)
+
+        self.ndim = len(data_shp)
+        if starting_index_offset is None:
+            starting_index_offset = np.zeros(self.data_shape.shape, dtype=int)
+        self.starting_index_offset = starting_index_offset
+
+    _si: List[np.ndarray] = None
+    _ei: List[np.ndarray] = None
+    _sizes: List[np.ndarray] = None
+
+    @property
+    def si(self) -> List[np.ndarray]:
+        """
+        The starting indices of individual chunks by dimension.
+        Includes any global offset.
+        """
+        if self._si is None:
+            si_list = []
+            ei_list = []
+            size_list = []
+            for idim in range(self.ndim):
+
+                # first get the starting and end points of whole chunks
+                si0 = self.starting_index_offset[idim]
+                si_0 = si0 + self.chunksizes[idim] * np.arange(self.n_whl_chnk[idim])
+                ei_0 = si_0 + self.chunksizes[idim]
+
+                # if this dim has a partial chunk at the end, add on a
+                # partial chunk.
+                if self.n_part_chnk[idim] == 1:
+                    si_0_partial = ei_0[-1]
+                    ei_0_partial = self.data_shape[idim] - si_0_partial
+                    si_0 = np.concatenate(
+                        [
+                            si_0,
+                            [
+                                si_0_partial,
+                            ],
+                        ]
+                    )
+                    ei_0 = np.concatenate(
+                        [
+                            ei_0,
+                            [
+                                ei_0[-1] + ei_0_partial,
+                            ],
+                        ]
+                    )
+                si_list.append(si_0)
+                ei_list.append(ei_0)
+                size_list.append(ei_0 - si_0)
+            self._si = si_list
+            self._ei = ei_list
+            self._sizes = size_list
+        return self._si
+
+    @property
+    def ei(self) -> List[np.ndarray]:
+        """
+        The ending indices of individual chunks by dimension.
+        Includes any global offset.
+        """
+        if self._ei is None:
+            _ = self.si
+        assert self._ei is not None
+        return self._ei
+
+    @property
+    def sizes(self) -> List[np.ndarray]:
+        if self._sizes is None:
+            _ = self.si
+        assert self._sizes is not None
+        return self._sizes
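The subtle piece of the new class is the partial-chunk branch inside ChunkInfo.si. A worked single-dimension trace of that branch, assuming no global offset (note that, despite its name, ei_0_partial holds the size of the trailing partial chunk, not an end index):

import numpy as np

# one dimension of length 12 with chunksize 5: two whole chunks plus a partial
si_0 = 0 + 5 * np.arange(2)       # whole-chunk starts: [0, 5]
ei_0 = si_0 + 5                   # whole-chunk ends:   [5, 10]

si_0_partial = ei_0[-1]           # partial chunk starts where whole chunks end: 10
ei_0_partial = 12 - si_0_partial  # its size: 2

si_0 = np.concatenate([si_0, [si_0_partial]])             # [0, 5, 10]
ei_0 = np.concatenate([ei_0, [ei_0[-1] + ei_0_partial]])  # [5, 10, 12]

sizes = ei_0 - si_0  # [5, 5, 2] -- the last chunk covers the remainder

Note also the caching design: computing si fills _si, _ei, and _sizes in a single pass, so accessing any one of the three properties populates the other two. That behavior is exactly what test_chunk_info_caching asserts.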
