diff --git a/src/hats/pixel_math/partition_stats.py b/src/hats/pixel_math/partition_stats.py index bc31345b..656c03e0 100644 --- a/src/hats/pixel_math/partition_stats.py +++ b/src/hats/pixel_math/partition_stats.py @@ -1,7 +1,10 @@ """Utilities for generating and manipulating object count histograms""" +import logging + import numpy as np import pandas as pd +import pyarrow.parquet as pq import hats.pixel_math.healpix_shim as hp @@ -54,7 +57,12 @@ def generate_histogram( def generate_alignment( - histogram, highest_order=10, lowest_order=0, threshold=1_000_000, drop_empty_siblings=False + histogram, + highest_order=10, + lowest_order=0, + threshold=1_000_000, + byte_pixel_threshold=None, + drop_empty_siblings=False, ): """Generate alignment from high order pixels to those of equal or lower order @@ -66,20 +74,27 @@ def generate_alignment( Args: histogram (:obj:`np.array`): one-dimensional numpy array of long integers where the value at each index corresponds to the number of objects found at the healpix pixel. - highest_order (int): the highest healpix order (e.g. 5-10) + highest_order (int): the highest healpix order (e.g. 5-10) lowest_order (int): the lowest healpix order (e.g. 1-5). specifying a lowest order constrains the partitioning to prevent spatially large pixels. threshold (int): the maximum number of objects allowed in a single pixel + byte_pixel_threshold (int | None): the maximum number of objects allowed in a single pixel, + expressed in bytes. if this is set, it will override `threshold`. drop_empty_siblings (bool): if 3 of 4 pixels are empty, keep only the non-empty pixel + Returns: one-dimensional numpy array of integer 3-tuples, where the value at each index corresponds to the destination pixel at order less than or equal to the `highest_order`. 
The tuple contains three integers: - - order of the destination pixel - pixel number *at the above order* - - the number of objects in the pixel + - the number of objects in the pixel (if partitioning by row count), or the memory size (if + partitioning by memory) + + Note: + If partitioning is done by memory size, the row count per partition may vary widely and will + not match the row count histogram's bins. Raises: ValueError: if the histogram is the wrong size, or some initial histogram bins exceed threshold. @@ -88,9 +103,21 @@ def generate_alignment( raise ValueError("histogram is not the right size") if lowest_order > highest_order: raise ValueError("lowest_order should be less than highest_order") + + # Determine aggregation type and threshold + if byte_pixel_threshold is not None: + agg_threshold = byte_pixel_threshold + agg_type = "mem_size" + else: + agg_threshold = threshold + agg_type = "row_count" + + # Check that none of the high-order pixels already exceed the threshold. 
def generate_row_count_histogram_from_partitions(partition_files, pixel_orders, pixel_indices, highest_order):
    """Generate a row count histogram from a list of partition files and their pixel indices/orders.

    Row counts are taken from each file's parquet footer metadata, so the data pages
    of the partitions are never loaded into memory.

    Args:
        partition_files (list[str or UPath]): List of paths to partition files.
        pixel_orders (list[int]): List of healpix orders for each partition.
        pixel_indices (list[int]): List of healpix pixel indices for each partition.
        highest_order (int): The highest healpix order (for histogram size).

    Returns:
        np.ndarray: One-dimensional numpy array of long integers, where the value at each index
        corresponds to the number of rows found at the healpix pixel.

    Raises:
        ValueError: if a partition's order is greater than `highest_order`.

    Note:
        If partitioning was done by memory size, this histogram will reflect the actual row counts
        in the output partitions, which may differ significantly from the original row count histogram.

        The true position of rows inside a partition coarser than `highest_order` is not known
        here, so its rows are spread as evenly as possible over the child pixels it covers.
        The spread always preserves the partition's total row count.

        Files that cannot be read are logged and skipped; they contribute nothing to the histogram.
    """
    histogram = np.zeros(hp.order2npix(highest_order), dtype=np.int64)
    for file_path, order, pix in zip(partition_files, pixel_orders, pixel_indices):
        if order > highest_order:
            # A finer pixel cannot be represented in a histogram of coarser pixels.
            raise ValueError(
                f"partition order {order} exceeds highest_order {highest_order} for file {file_path}"
            )

        # Keep the try body to the I/O only, so arithmetic errors below are never
        # mistaken for an unreadable file.
        try:
            row_count = pq.read_metadata(file_path).num_rows
        except (OSError, ValueError) as e:
            logging.warning("Could not read partition file %s: %s", file_path, e)
            continue

        if order == highest_order:
            # Direct indexing (rather than a slice) so an out-of-range pixel still raises.
            histogram[pix] += row_count
            continue

        # Each lower order pixel covers 4**(highest_order - order) pixels at highest_order.
        factor = 4 ** (highest_order - order)
        start = pix * factor
        # Plain floor division would silently drop `row_count % factor` rows (all of them
        # when row_count < factor); hand the remainder out one row at a time instead.
        base, remainder = divmod(row_count, factor)
        histogram[start : start + factor] += base
        histogram[start : start + remainder] += 1
    return histogram