
Commit 87aef5e

Add mem_size histogram (#619)

* Set up initial branching logic to optionally create mem_size histogram
* Create mem_size histogram and use in mapping stage (not yet used for partitioning)
* Add unit tests for mem_size hist
* Add pragma no cover to read_histogram logic that will be tested in next PR

1 parent fb2e1f3 commit 87aef5e
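
A usage note (not part of the commit): the sketch below shows how the new byte-based threshold might be passed to a catalog import, assuming the ImportArguments class changed in this commit. Only byte_pixel_threshold and pixel_threshold come from this change; the other field names, paths, and reader shorthand are illustrative assumptions.

from hats_import.catalog.arguments import ImportArguments

# Hypothetical import configuration; only byte_pixel_threshold is new in this
# commit, the remaining values are placeholders for illustration.
args = ImportArguments(
    output_artifact_name="my_catalog",       # placeholder catalog name
    input_path="/data/my_survey",            # placeholder input directory
    file_reader="parquet",                   # assumed reader shorthand
    ra_column="ra",
    dec_column="dec",
    # New option: cap each resulting pixel by estimated row size in bytes.
    # When set, it overrides the row-count based pixel_threshold.
    byte_pixel_threshold=500 * 1024 * 1024,  # ~500 MiB per resulting pixel
    output_path="/data/hats_catalogs",       # placeholder output directory
)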

File tree

8 files changed, +484 -35 lines changed

src/hats_import/catalog/arguments.py

Lines changed: 12 additions & 0 deletions
@@ -87,6 +87,11 @@ class ImportArguments(RuntimeArguments):
     """when determining bins for the final partitioning, the maximum number
     of rows for a single resulting pixel. we may combine hierarchically until
     we near the ``pixel_threshold``"""
+    byte_pixel_threshold: int | None = None
+    """when determining bins for the final partitioning, the maximum total size
+    of rows for a single resulting pixel, expressed in bytes. we may combine hierarchically until
+    we near the ``byte_pixel_threshold``. if this is set, it will override
+    ``pixel_threshold``."""
     drop_empty_siblings: bool = True
     """when determining bins for the final partitioning, should we keep result pixels
     at a higher order (smaller area) if the 3 sibling pixels are empty. setting this to
@@ -144,6 +149,13 @@ def _check_arguments(self):
             if self.sort_columns:
                 raise ValueError("When using _healpix_29 for position, no sort columns should be added")

+        # Validate byte_pixel_threshold
+        if self.byte_pixel_threshold is not None:
+            if not isinstance(self.byte_pixel_threshold, int):
+                raise TypeError("byte_pixel_threshold must be an integer")
+            if self.byte_pixel_threshold < 0:
+                raise ValueError("byte_pixel_threshold must be non-negative")
+
         # Basic checks complete - make more checks and create directories where necessary
         self.input_paths = find_input_paths(self.input_path, "**/*.*", self.input_file_list)
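
A standalone note: the validation above only checks the type and sign of the new argument. The sketch below mirrors those checks outside the class; the function name is invented here for illustration and is not library API.

def validate_byte_pixel_threshold(value):
    """Mirror of the checks added to ImportArguments._check_arguments above."""
    if value is not None:
        if not isinstance(value, int):
            raise TypeError("byte_pixel_threshold must be an integer")
        if value < 0:
            raise ValueError("byte_pixel_threshold must be non-negative")
    return value

validate_byte_pixel_threshold(None)         # allowed: fall back to pixel_threshold
validate_byte_pixel_threshold(100_000_000)  # allowed: roughly 100 MB per pixel
# validate_byte_pixel_threshold(-1)         would raise ValueError
# validate_byte_pixel_threshold("100MB")    would raise TypeError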

src/hats_import/catalog/map_reduce.py

Lines changed: 111 additions & 7 deletions
@@ -1,6 +1,8 @@
 """Import a set of non-hats files using dask for parallelization"""

 import pickle
+import sys
+from collections import defaultdict

 import cloudpickle
 import hats.pixel_math.healpix_shim as hp
@@ -86,6 +88,7 @@ def map_to_pixels(
     ra_column,
     dec_column,
     use_healpix_29=False,
+    threshold_mode="row_count",
 ):
     """Map a file of input objects to their healpix pixels.

@@ -99,6 +102,7 @@ def map_to_pixels(
         highest_order (int): healpix order to use when mapping
         ra_column (str): where to find right ascension data in the dataframe
         dec_column (str): where to find declination in the dataframe
+        threshold_mode (str): mode for thresholding, either "row_count" or "mem_size".

     Returns:
         one-dimensional numpy array of long integers where the value at each index corresponds
@@ -108,14 +112,24 @@
         FileNotFoundError: if the file does not exist, or is a directory
     """
     try:
-        histo = HistogramAggregator(highest_order)
-
-        if use_healpix_29:
+        # Always generate the row-count histogram.
+        row_count_histo = HistogramAggregator(highest_order)
+        mem_size_histo = None
+        if threshold_mode == "mem_size":
+            mem_size_histo = HistogramAggregator(highest_order)
+
+        # Determine which columns to read from the input file. If we're using
+        # the bytewise/mem_size histogram, we need to read all columns to accurately
+        # estimate memory usage.
+        if threshold_mode == "mem_size":
+            read_columns = None
+        elif use_healpix_29:
             read_columns = [SPATIAL_INDEX_COLUMN]
         else:
             read_columns = [ra_column, dec_column]

-        for _, _, mapped_pixels in _iterate_input_file(
+        # Iterate through the input file in chunks, mapping pixels and updating histograms.
+        for _, chunk_data, mapped_pixels in _iterate_input_file(
             input_file,
             pickled_reader_file,
             highest_order,
@@ -124,18 +138,108 @@
             use_healpix_29,
             read_columns,
         ):
+            # Always add to row_count histogram.
             mapped_pixel, count_at_pixel = np.unique(mapped_pixels, return_counts=True)
+            row_count_histo.add(SparseHistogram(mapped_pixel, count_at_pixel, highest_order))
+
+            # If using bytewise/mem_size thresholding, also add to mem_size histogram.
+            if threshold_mode == "mem_size":
+                data_mem_sizes = _get_mem_size_of_chunk(chunk_data)
+                pixel_mem_sizes: dict[int, int] = defaultdict(int)
+                for pixel, mem_size in zip(mapped_pixels, data_mem_sizes, strict=True):
+                    pixel_mem_sizes[pixel] += mem_size
+
+                # Turn our dict into two lists, the keys and vals, sorted so the keys are increasing
+                mapped_pixel_ids = np.array(list(pixel_mem_sizes.keys()), dtype=np.int64)
+                mapped_pixel_mem_sizes = np.array(list(pixel_mem_sizes.values()), dtype=np.int64)
+
+                if mem_size_histo is not None:
+                    mem_size_histo.add(
+                        SparseHistogram(mapped_pixel_ids, mapped_pixel_mem_sizes, highest_order)
+                    )

-            histo.add(SparseHistogram(mapped_pixel, count_at_pixel, highest_order))
-
-        histo.to_sparse().to_file(
+        # Write row_count histogram to file.
+        row_count_histo.to_sparse().to_file(
             ResumePlan.partial_histogram_file(tmp_path=resume_path, mapping_key=mapping_key)
         )
+        # If using bytewise/mem_size thresholding, also write mem_size histogram to a separate file.
+        if threshold_mode == "mem_size" and mem_size_histo is not None:
+            mem_size_histo.to_sparse().to_file(
+                ResumePlan.partial_histogram_file(
+                    tmp_path=resume_path, mapping_key=f"{mapping_key}", which_histogram="mem_size"
+                )
+            )
     except Exception as exception:  # pylint: disable=broad-exception-caught
         print_task_failure(f"Failed MAPPING stage with file {input_file}", exception)
         raise exception


+def _get_row_mem_size_data_frame(row):
+    """Given a pandas dataframe row (as a tuple), return the memory size of that row.
+
+    Args:
+        row (tuple): the row from the dataframe
+
+    Returns:
+        int: the memory size of the row in bytes
+    """
+    total = 0
+
+    # Add the memory overhead of the row object itself.
+    total += sys.getsizeof(row)
+
+    # Then add the size of each item in the row.
+    for item in row:
+        if isinstance(item, np.ndarray):
+            total += item.nbytes + sys.getsizeof(item)  # object data + object overhead
+        else:
+            total += sys.getsizeof(item)
+    return total
+
+
+def _get_row_mem_size_pa_table(table, row_index):
+    """Given a pyarrow table and a row index, return the memory size of that row.
+
+    Args:
+        table (pa.Table): the pyarrow table
+        row_index (int): the index of the row to measure
+
+    Returns:
+        int: the memory size of the row in bytes
+    """
+    total = 0
+
+    # Add the memory overhead of the row object itself.
+    total += sys.getsizeof(row_index)
+
+    # Then add the size of each item in the row.
+    for column in table.itercolumns():
+        item = column[row_index]
+        if isinstance(item, np.ndarray):
+            total += item.nbytes + sys.getsizeof(item)  # object data + object overhead
+        else:
+            total += sys.getsizeof(item.as_py())
+    return total
+
+
+def _get_mem_size_of_chunk(data):
+    """Given a 2D array of data, return a list of memory sizes for each row in the chunk.
+
+    Args:
+        data (pd.DataFrame or pa.Table): the data chunk to measure
+
+    Returns:
+        list[int]: list of memory sizes for each row in the chunk
+    """
+    if isinstance(data, pd.DataFrame):
+        mem_sizes = [_get_row_mem_size_data_frame(row) for row in data.itertuples(index=False, name=None)]
+    elif isinstance(data, pa.Table):
+        mem_sizes = [_get_row_mem_size_pa_table(data, i) for i in range(data.num_rows)]
+    else:
+        raise NotImplementedError(f"Unsupported data type {type(data)} for memory size calculation")
+    return mem_sizes
+
+
 def split_pixels(
     input_file: UPath,
     pickled_reader_file: str,
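
To make the per-row accounting concrete, the self-contained sketch below applies the same sys.getsizeof-based estimate that _get_mem_size_of_chunk and _get_row_mem_size_data_frame use, on a toy pandas chunk; the column names and values are made up for illustration.

import sys

import numpy as np
import pandas as pd

# Toy chunk standing in for one block yielded by _iterate_input_file.
chunk = pd.DataFrame(
    {
        "ra": np.array([10.1, 10.2], dtype=np.float64),
        "dec": np.array([-5.0, -5.1], dtype=np.float64),
        "id": ["obj_1", "obj_2"],
    }
)

def row_mem_size(row):
    """Same idea as _get_row_mem_size_data_frame: tuple overhead plus each item's size."""
    total = sys.getsizeof(row)
    for item in row:
        if isinstance(item, np.ndarray):
            total += item.nbytes + sys.getsizeof(item)  # array data + object overhead
        else:
            total += sys.getsizeof(item)
    return total

# One estimate per row; these per-row sizes are what get summed per healpix
# pixel and accumulated into the mem_size histogram during the mapping stage.
mem_sizes = [row_mem_size(row) for row in chunk.itertuples(index=False, name=None)]
print(mem_sizes)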
