diff --git a/src/hats_import/catalog/arguments.py b/src/hats_import/catalog/arguments.py index acd78990..8317cebf 100644 --- a/src/hats_import/catalog/arguments.py +++ b/src/hats_import/catalog/arguments.py @@ -87,6 +87,11 @@ class ImportArguments(RuntimeArguments): """when determining bins for the final partitioning, the maximum number of rows for a single resulting pixel. we may combine hierarchically until we near the ``pixel_threshold``""" + byte_pixel_threshold: int | None = None + """when determining bins for the final partitioning, the maximum number + of rows for a single resulting pixel, expressed in bytes. we may combine hierarchically until + we near the ``byte_pixel_threshold``. if this is set, it will override + ``pixel_threshold``.""" drop_empty_siblings: bool = True """when determining bins for the final partitioning, should we keep result pixels at a higher order (smaller area) if the 3 sibling pixels are empty. setting this to @@ -144,6 +149,13 @@ def _check_arguments(self): if self.sort_columns: raise ValueError("When using _healpix_29 for position, no sort columns should be added") + # Validate byte_pixel_threshold + if self.byte_pixel_threshold is not None: + if not isinstance(self.byte_pixel_threshold, int): + raise TypeError("byte_pixel_threshold must be an integer") + if self.byte_pixel_threshold < 0: + raise ValueError("byte_pixel_threshold must be non-negative") + # Basic checks complete - make more checks and create directories where necessary self.input_paths = find_input_paths(self.input_path, "**/*.*", self.input_file_list) diff --git a/src/hats_import/catalog/map_reduce.py b/src/hats_import/catalog/map_reduce.py index c867e23d..ddb03065 100644 --- a/src/hats_import/catalog/map_reduce.py +++ b/src/hats_import/catalog/map_reduce.py @@ -1,6 +1,8 @@ """Import a set of non-hats files using dask for parallelization""" import pickle +import sys +from collections import defaultdict import cloudpickle import hats.pixel_math.healpix_shim as hp @@ -86,6 +88,7 @@ def map_to_pixels( ra_column, dec_column, use_healpix_29=False, + threshold_mode="row_count", ): """Map a file of input objects to their healpix pixels. @@ -99,6 +102,7 @@ def map_to_pixels( highest_order (int): healpix order to use when mapping ra_column (str): where to find right ascension data in the dataframe dec_column (str): where to find declation in the dataframe + threshold_mode (str): mode for thresholding, either "row_count" or "mem_size". Returns: one-dimensional numpy array of long integers where the value at each index corresponds @@ -108,14 +112,24 @@ def map_to_pixels( FileNotFoundError: if the file does not exist, or is a directory """ try: - histo = HistogramAggregator(highest_order) - - if use_healpix_29: + # Always generate the row-count histogram. + row_count_histo = HistogramAggregator(highest_order) + mem_size_histo = None + if threshold_mode == "mem_size": + mem_size_histo = HistogramAggregator(highest_order) + + # Determine which columns to read from the input file. If we're using + # the bytewise/mem_size histogram, we need to read all columns to accurately + # estimate memory usage. + if threshold_mode == "mem_size": + read_columns = None + elif use_healpix_29: read_columns = [SPATIAL_INDEX_COLUMN] else: read_columns = [ra_column, dec_column] - for _, _, mapped_pixels in _iterate_input_file( + # Iterate through the input file in chunks, mapping pixels and updating histograms. + for _, chunk_data, mapped_pixels in _iterate_input_file( input_file, pickled_reader_file, highest_order, @@ -124,18 +138,108 @@ def map_to_pixels( use_healpix_29, read_columns, ): + # Always add to row_count histogram. mapped_pixel, count_at_pixel = np.unique(mapped_pixels, return_counts=True) + row_count_histo.add(SparseHistogram(mapped_pixel, count_at_pixel, highest_order)) + + # If using bytewise/mem_size thresholding, also add to mem_size histogram. + if threshold_mode == "mem_size": + data_mem_sizes = _get_mem_size_of_chunk(chunk_data) + pixel_mem_sizes: dict[int, int] = defaultdict(int) + for pixel, mem_size in zip(mapped_pixels, data_mem_sizes, strict=True): + pixel_mem_sizes[pixel] += mem_size + + # Turn our dict into two lists, the keys and vals, sorted so the keys are increasing + mapped_pixel_ids = np.array(list(pixel_mem_sizes.keys()), dtype=np.int64) + mapped_pixel_mem_sizes = np.array(list(pixel_mem_sizes.values()), dtype=np.int64) + + if mem_size_histo is not None: + mem_size_histo.add( + SparseHistogram(mapped_pixel_ids, mapped_pixel_mem_sizes, highest_order) + ) - histo.add(SparseHistogram(mapped_pixel, count_at_pixel, highest_order)) - - histo.to_sparse().to_file( + # Write row_count histogram to file. + row_count_histo.to_sparse().to_file( ResumePlan.partial_histogram_file(tmp_path=resume_path, mapping_key=mapping_key) ) + # If using bytewise/mem_size thresholding, also write mem_size histogram to a separate file. + if threshold_mode == "mem_size" and mem_size_histo is not None: + mem_size_histo.to_sparse().to_file( + ResumePlan.partial_histogram_file( + tmp_path=resume_path, mapping_key=f"{mapping_key}", which_histogram="mem_size" + ) + ) except Exception as exception: # pylint: disable=broad-exception-caught print_task_failure(f"Failed MAPPING stage with file {input_file}", exception) raise exception +def _get_row_mem_size_data_frame(row): + """Given a pandas dataframe row (as a tuple), return the memory size of that row. + + Args: + row (tuple): the row from the dataframe + + Returns: + int: the memory size of the row in bytes + """ + total = 0 + + # Add the memory overhead of the row object itself. + total += sys.getsizeof(row) + + # Then add the size of each item in the row. + for item in row: + if isinstance(item, np.ndarray): + total += item.nbytes + sys.getsizeof(item) # object data + object overhead + else: + total += sys.getsizeof(item) + return total + + +def _get_row_mem_size_pa_table(table, row_index): + """Given a pyarrow table and a row index, return the memory size of that row. + + Args: + table (pa.Table): the pyarrow table + row_index (int): the index of the row to measure + + Returns: + int: the memory size of the row in bytes + """ + total = 0 + + # Add the memory overhead of the row object itself. + total += sys.getsizeof(row_index) + + # Then add the size of each item in the row. + for column in table.itercolumns(): + item = column[row_index] + if isinstance(item, np.ndarray): + total += item.nbytes + sys.getsizeof(item) # object data + object overhead + else: + total += sys.getsizeof(item.as_py()) + return total + + +def _get_mem_size_of_chunk(data): + """Given a 2D array of data, return a list of memory sizes for each row in the chunk. + + Args: + data (pd.DataFrame or pa.Table): the data chunk to measure + + Returns: + list[int]: list of memory sizes for each row in the chunk + """ + if isinstance(data, pd.DataFrame): + mem_sizes = [_get_row_mem_size_data_frame(row) for row in data.itertuples(index=False, name=None)] + elif isinstance(data, pa.Table): + mem_sizes = [_get_row_mem_size_pa_table(data, i) for i in range(data.num_rows)] + else: + raise NotImplementedError(f"Unsupported data type {type(data)} for memory size calculation") + return mem_sizes + + def split_pixels( input_file: UPath, pickled_reader_file: str, diff --git a/src/hats_import/catalog/resume_plan.py b/src/hats_import/catalog/resume_plan.py index ea0c155a..1871add3 100644 --- a/src/hats_import/catalog/resume_plan.py +++ b/src/hats_import/catalog/resume_plan.py @@ -23,13 +23,16 @@ class ResumePlan(PipelineResumePlan): """Container class for holding the state of each file in the pipeline plan.""" input_paths: list[UPath] = field(default_factory=list) - """resolved list of all files that will be used in the importer""" + """Resolved list of all files that will be used in the importer""" map_files: list[tuple[str, str]] = field(default_factory=list) - """list of files (and job keys) that have yet to be mapped""" + """List of files (and job keys) that have yet to be mapped""" split_keys: list[tuple[str, str]] = field(default_factory=list) - """set of files (and job keys) that have yet to be split""" + """Set of files (and job keys) that have yet to be split""" destination_pixel_map: dict[HealpixPixel, int] | None = None """Destination pixels and their expected final count""" + threshold_mode: str = "row_count" + """Which mode to use for partitioning: 'row_count' or 'mem_size'. + Determines whether to create additional mem_size histogram.""" should_run_mapping: bool = True should_run_splitting: bool = True should_run_reducing: bool = True @@ -41,6 +44,8 @@ class ResumePlan(PipelineResumePlan): ROW_COUNT_HISTOGRAM_BINARY_FILE = "row_count_mapping_histogram.npz" ROW_COUNT_HISTOGRAMS_DIR = "row_count_histograms" + MEM_SIZE_HISTOGRAM_BINARY_FILE = "mem_size_mapping_histogram.npz" + MEM_SIZE_HISTOGRAMS_DIR = "mem_size_histograms" ALIGNMENT_FILE = "alignment.pickle" @@ -63,6 +68,10 @@ def __init__( if import_args.debug_stats_only: run_stages = ["mapping", "finishing"] self.input_paths = import_args.input_paths + + # Set threshold_mode based on byte_pixel_threshold + if hasattr(import_args, "byte_pixel_threshold") and import_args.byte_pixel_threshold is not None: + self.threshold_mode = "mem_size" else: super().__init__( resume=resume, @@ -118,6 +127,13 @@ def gather_plan(self, run_stages: list[str] | None = None): file_io.append_paths_to_pointer(self.tmp_path, self.ROW_COUNT_HISTOGRAMS_DIR), exist_ok=True, ) + # If using mem_size thresholding, gather those keys too. + if self.threshold_mode == "mem_size": + self.get_remaining_map_keys(which_histogram="mem_size") + file_io.make_directory( + file_io.append_paths_to_pointer(self.tmp_path, self.MEM_SIZE_HISTOGRAMS_DIR), + exist_ok=True, + ) if self.should_run_splitting: if not (mapping_done or self.should_run_mapping): raise ValueError("mapping must be complete before splitting") @@ -139,50 +155,99 @@ def gather_plan(self, run_stages: list[str] | None = None): ) step_progress.update(1) - def get_remaining_map_keys(self): + def get_remaining_map_keys(self, which_histogram: str = "row_count"): """Gather remaining keys, dropping successful mapping tasks from histogram names. + Args: + which_histogram (str): Which histogram to check for completed tasks, either 'row_count' + or 'mem_size'. Defaults to 'row_count'. + Returns: - list of mapping keys *not* found in files like /resume/path/mapping_key.npz + list of tuple: The mapping keys *not* found in files like /resume/path/mapping_key.npz, + along with their corresponding input file paths. + + Raises: + ValueError: If `which_histogram` is not recognized, or if which_histogram is + 'mem_size' but the threshold_mode is not 'mem_size'. """ - prefix = file_io.get_upath(self.tmp_path) / self.ROW_COUNT_HISTOGRAMS_DIR + if which_histogram == "row_count": + prefix = file_io.get_upath(self.tmp_path) / self.ROW_COUNT_HISTOGRAMS_DIR + elif which_histogram == "mem_size" and self.threshold_mode == "mem_size": + prefix = file_io.get_upath(self.tmp_path) / self.MEM_SIZE_HISTOGRAMS_DIR + elif which_histogram == "mem_size": + raise ValueError("Cannot get remaining mem_size map keys when threshold_mode is not 'mem_size'.") + else: + raise ValueError(f"Unrecognized which_histogram value: {which_histogram}") + map_file_pattern = re.compile(r"map_(\d+).npz") - done_indexes = [int(map_file_pattern.match(path.name).group(1)) for path in prefix.glob("*.npz")] + done_indexes = [ + int(match.group(1)) + for path in prefix.glob("*.npz") + if (match := map_file_pattern.match(path.name)) + ] remaining_indexes = list(set(range(0, len(self.input_paths))) - (set(done_indexes))) return [(f"map_{key}", self.input_paths[key]) for key in remaining_indexes] - def read_histogram(self, healpix_order): - """Return histogram with healpix_order'd shape + def read_histogram(self, healpix_order, which_histogram: str = "row_count"): + """Returns a histogram with the specified Healpix order's shape. - - Try to find a combined histogram - - Otherwise, combine histograms from partials - - Otherwise, return an empty histogram - """ - file_name = file_io.append_paths_to_pointer(self.tmp_path, self.ROW_COUNT_HISTOGRAM_BINARY_FILE) + This method attempts the following steps in order: + 1. Tries to locate and return a combined histogram. + 2. If a combined histogram is unavailable, combines partial histograms to create one. + 3. If no partial histograms are found, returns an empty histogram. + + Args: + healpix_order (int): The desired Healpix order for the histogram. + which_histogram (str): Which histogram to read, either "row_count" or "mem_size". + Defaults to "row_count". - # Otherwise, read the histogram from partial histograms and combine. + + Returns: + numpy.ndarray: A one-dimensional array representing the histogram with the + specified Healpix order. + + Raises: + RuntimeError: If there are incomplete mapping stages. + ValueError: If the histogram from the previous execution is incompatible with + the highest Healpix order, or if `which_histogram` is invalid. + """ + # Temporarily skipping the next block from code coverage, as we will be calling it once we + # address the binning stage and onwards in the next PR. + # Mocking this in the meantime for temporary tests doesn't seem worth the complication). + if which_histogram == "row_count": + histogram_binary_file = self.ROW_COUNT_HISTOGRAM_BINARY_FILE + histogram_directory = self.ROW_COUNT_HISTOGRAMS_DIR + elif which_histogram == "mem_size" and self.threshold_mode == "mem_size": # pragma: no cover + histogram_binary_file = self.MEM_SIZE_HISTOGRAM_BINARY_FILE + histogram_directory = self.MEM_SIZE_HISTOGRAMS_DIR + elif which_histogram == "mem_size": # pragma: no cover + raise ValueError("Cannot read mem_size histogram when threshold_mode is not 'mem_size'.") + else: # pragma: no cover + raise ValueError(f"Unrecognized which_histogram value: {which_histogram}") + + file_name = file_io.append_paths_to_pointer(self.tmp_path, histogram_binary_file) + + # If no file, read the histogram from partial histograms and combine. if not file_io.does_file_or_directory_exist(file_name): remaining_map_files = self.get_remaining_map_keys() if len(remaining_map_files) > 0: raise RuntimeError(f"{len(remaining_map_files)} map stages did not complete successfully.") - histogram_files = file_io.find_files_matching_path( - self.tmp_path, self.ROW_COUNT_HISTOGRAMS_DIR, "*.npz" - ) + histogram_files = file_io.find_files_matching_path(self.tmp_path, histogram_directory, "*.npz") aggregate_histogram = HistogramAggregator(healpix_order) for partial_file_name in histogram_files: partial = SparseHistogram.from_file(partial_file_name) aggregate_histogram.add(partial) - file_name = file_io.append_paths_to_pointer(self.tmp_path, self.ROW_COUNT_HISTOGRAM_BINARY_FILE) - with open(file_name, "wb+") as file_handle: + file_name = file_io.append_paths_to_pointer(self.tmp_path, histogram_binary_file) + with file_name.open("wb+") as file_handle: file_handle.write(aggregate_histogram.full_histogram) if self.delete_resume_log_files: file_io.remove_directory( - file_io.append_paths_to_pointer(self.tmp_path, self.ROW_COUNT_HISTOGRAMS_DIR), + file_io.append_paths_to_pointer(self.tmp_path, histogram_directory), ignore_errors=True, ) - with open(file_name, "rb") as file_handle: + with file_name.open("rb") as file_handle: full_histogram = frombuffer(file_handle.read(), dtype=np.int64) if len(full_histogram) != hp.order2npix(healpix_order): @@ -194,7 +259,7 @@ def read_histogram(self, healpix_order): return full_histogram @classmethod - def partial_histogram_file(cls, tmp_path, mapping_key: str): + def partial_histogram_file(cls, tmp_path, mapping_key: str, which_histogram: str = "row_count"): """File name for writing a histogram file to a special intermediate directory. As a side effect, this method may create the special intermediate directory. @@ -202,12 +267,24 @@ def partial_histogram_file(cls, tmp_path, mapping_key: str): Args: tmp_path (str): where to write intermediate resume files. mapping_key (str): unique string for each mapping task (e.g. "map_57") + which_histogram (str): which histogram to write, either "row_count" or "mem_size". + Defaults to "row_count". + + Returns: + str: Full path to the partial histogram file. """ + if which_histogram == "row_count": + histograms_dir = cls.ROW_COUNT_HISTOGRAMS_DIR + elif which_histogram == "mem_size": + histograms_dir = cls.MEM_SIZE_HISTOGRAMS_DIR + else: + raise ValueError(f"Unrecognized which_histogram value: {which_histogram}") + file_io.make_directory( - file_io.append_paths_to_pointer(tmp_path, cls.ROW_COUNT_HISTOGRAMS_DIR), + file_io.append_paths_to_pointer(tmp_path, histograms_dir), exist_ok=True, ) - return file_io.append_paths_to_pointer(tmp_path, cls.ROW_COUNT_HISTOGRAMS_DIR, f"{mapping_key}.npz") + return file_io.append_paths_to_pointer(tmp_path, histograms_dir, f"{mapping_key}.npz") def get_remaining_split_keys(self): """Gather remaining keys, dropping successful split tasks from done file names. diff --git a/src/hats_import/catalog/run_import.py b/src/hats_import/catalog/run_import.py index 34c72d36..c0e6a693 100644 --- a/src/hats_import/catalog/run_import.py +++ b/src/hats_import/catalog/run_import.py @@ -48,12 +48,13 @@ def run(args, client): ra_column=args.ra_column, dec_column=args.dec_column, use_healpix_29=args.use_healpix_29, + threshold_mode=resume_plan.threshold_mode, ) ) resume_plan.wait_for_mapping(futures) with resume_plan.print_progress(total=2, stage_name="Binning") as step_progress: - raw_histogram = resume_plan.read_histogram(args.mapping_healpix_order) + raw_histogram = resume_plan.read_histogram(args.mapping_healpix_order, which_histogram="row_count") total_rows = int(raw_histogram.sum()) if args.expected_total_rows > 0 and args.expected_total_rows != total_rows: raise ValueError( diff --git a/tests/hats_import/catalog/test_argument_validation.py b/tests/hats_import/catalog/test_argument_validation.py index 795ccb66..62b2236f 100644 --- a/tests/hats_import/catalog/test_argument_validation.py +++ b/tests/hats_import/catalog/test_argument_validation.py @@ -295,3 +295,23 @@ def test_no_import_overwrite(small_sky_object_catalog, parquet_shards_dir): output_artifact_name=catalog_name, file_reader="parquet", ) + + +def test_bad_byte_pixel_thresholds(blank_data_dir, tmp_path): + """Test that invalid pixel thresholds raise errors.""" + with pytest.raises(TypeError, match="byte_pixel_threshold must be an integer"): + ImportArguments( + output_artifact_name="catalog", + input_path=blank_data_dir, + file_reader="csv", + output_path=tmp_path, + byte_pixel_threshold=4.2, + ) + with pytest.raises(ValueError, match="byte_pixel_threshold must be non-negative"): + ImportArguments( + output_artifact_name="catalog", + input_path=blank_data_dir, + file_reader="csv", + output_path=tmp_path, + byte_pixel_threshold=-5, + ) diff --git a/tests/hats_import/catalog/test_map_reduce.py b/tests/hats_import/catalog/test_map_reduce.py index 69a2e20d..961f78d9 100644 --- a/tests/hats_import/catalog/test_map_reduce.py +++ b/tests/hats_import/catalog/test_map_reduce.py @@ -13,6 +13,7 @@ import pyarrow as pa import pytest from hats.pixel_math.sparse_histogram import SparseHistogram +from nested_pandas.nestedframe import NestedFrame import hats_import.catalog.map_reduce as mr from hats_import.catalog.file_readers import get_file_reader @@ -85,9 +86,9 @@ def test_read_bad_fileformat(blank_data_file, capsys, tmp_path): assert "No file reader implemented" in captured.out -def read_partial_histogram(tmp_path, mapping_key): +def read_partial_histogram(tmp_path, mapping_key, which_histogram="row_count"): """Helper to read in the former result of a map operation.""" - histogram_file = tmp_path / "row_count_histograms" / f"{mapping_key}.npz" + histogram_file = tmp_path / f"{which_histogram}_histograms" / f"{mapping_key}.npz" hist = SparseHistogram.from_file(histogram_file) return hist.to_array() @@ -683,3 +684,141 @@ def test_reduce_with_sorting_complex(assert_parquet_file_ids, tmp_path): [1206, 1200, 1201, 1309, 1308, 1307, 1402, 1403, 1404, 1505], resort_ids=False, ) + + +@pytest.mark.parametrize("threshold_mode", ["row_count", "mem_size"]) +def test_histogram_file_contents(tmp_path, small_sky_single_file, threshold_mode): + """Test that map_to_pixels writes correct histogram file contents for both row_count and mem_size.""" + + # Run mapping stage (with threshold_mode parameterized). + mr.map_to_pixels( + input_file=small_sky_single_file, + pickled_reader_file=pickle_file_reader(tmp_path, get_file_reader("csv")), + highest_order=0, + ra_column="ra", + dec_column="dec", + resume_path=tmp_path, + mapping_key="map_0", + threshold_mode=threshold_mode, + ) + + # Check histogram contents. + result = read_partial_histogram(tmp_path, "map_0", which_histogram=threshold_mode) + if threshold_mode == "row_count": + assert len(result) == 12 + assert result[11] == 131 # known value for this fixture + assert result.sum() == 131 + elif threshold_mode == "mem_size": + assert len(result) == 12 + # Let's consider a very conservative lower bound for memory size. + # Each row has at least 3 columns: id (int64), ra (float64), dec (float64). + # Each int64 takes 8 bytes, each float64 takes 8 bytes. + # So each row takes at least 8 + 8 + 8 = 24 bytes. + # For 131 rows, minimum memory size is 131 * 24 = 3144 bytes. + # However, since we are using a histogram with bins of size powers of two, + # we expect the mem_size histogram to have a single bin that is at least 3144. + assert result[11] >= 3144 + assert result.sum() >= 3144 + + +def test_get_mem_size_of_chunk(): + """Test the _get_mem_size_of_chunk function for reasonable outputs.""" + # Test with an empty DataFrame + empty_df = pd.DataFrame(columns=["id", "ra", "dec", "value"]) + mem_sizes_empty = mr._get_mem_size_of_chunk(empty_df) + assert len(mem_sizes_empty) == 0 + + # Test with a small DataFrame + df = pd.DataFrame( + { + "id": [0, 0, 0, 1, 1, 1, 2, 2, 2, 2], + "ra": [10.0, 10.0, 10.0, 15.0, 15.0, 15.0, 12.1, 12.1, 12.1, 12.1], + "dec": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0, 0.5, 0.5, 0.5, 0.5], + "time": [ + 60676.0, + 60677.0, + 60678.0, + 60675.0, + 60676.5, + 60677.0, + 60676.6, + 60676.7, + 60676.8, + 60676.9, + ], + "brightness": [100.0, 101.0, 99.8, 5.0, 5.01, 4.98, 20.1, 20.5, 20.3, 20.2], + "band": ["g", "r", "g", "r", "g", "r", "g", "g", "r", "r"], + } + ) + mem_sizes = mr._get_mem_size_of_chunk(df) + # Since we have 10 rows, mem_sizes should have length 10 + assert len(mem_sizes) == 10 + # Each entry should be a positive integer (size in bytes) + assert all(isinstance(size, int) and size > 0 for size in mem_sizes) + + # Compare to a smaller DataFrame with fewer columns + df_small = df[["id", "ra", "dec"]] + mem_sizes_small = mr._get_mem_size_of_chunk(df_small) + assert len(mem_sizes_small) == 10 + assert all(isinstance(size, int) and size > 0 for size in mem_sizes_small) + # Each entry in mem_sizes should be > corresponding entry in mem_sizes_small + assert all(m > s for m, s in zip(mem_sizes, mem_sizes_small, strict=True)) + + # Test with a pyarrow Table + table = pa.Table.from_pandas(df) + mem_sizes_table = mr._get_mem_size_of_chunk(table) + assert len(mem_sizes_table) == 10 + assert all(isinstance(size, int) and size > 0 for size in mem_sizes_table) + + # Test with a smaller pyarrow Table + table_small = pa.Table.from_pandas(df_small) + mem_sizes_table_small = mr._get_mem_size_of_chunk(table_small) + assert len(mem_sizes_table_small) == 10 + assert all(isinstance(size, int) and size > 0 for size in mem_sizes_table_small) + # Each entry in mem_sizes_table should be > corresponding entry in mem_sizes_table_small + assert all(m > s for m, s in zip(mem_sizes_table, mem_sizes_table_small, strict=True)) + + +def test_get_mem_size_of_chunk_nested(): + """Test the _get_mem_size_of_chunk function with nested data.""" + # Create a small DataFrame and nest it + df = pd.DataFrame( + { + "id": [0, 0, 0, 1, 1, 1, 2, 2, 2, 2], + "ra": [10.0, 10.0, 10.0, 15.0, 15.0, 15.0, 12.1, 12.1, 12.1, 12.1], + "dec": [0.0, 0.0, 0.0, -1.0, -1.0, -1.0, 0.5, 0.5, 0.5, 0.5], + "time": [ + 60676.0, + 60677.0, + 60678.0, + 60675.0, + 60676.5, + 60677.0, + 60676.6, + 60676.7, + 60676.8, + 60676.9, + ], + "brightness": [100.0, 101.0, 99.8, 5.0, 5.01, 4.98, 20.1, 20.5, 20.3, 20.2], + "band": ["g", "r", "g", "r", "g", "r", "g", "g", "r", "r"], + } + ) + nf = NestedFrame.from_flat( + df, + base_columns=["ra", "dec"], + nested_columns=["time", "brightness", "band"], + on="id", + name="lightcurve", + ) + + # Calculate memory sizes + mem_sizes = mr._get_mem_size_of_chunk(nf) + + # Since we have only 3 rows once we nest, mem_sizes should have length 3 + assert len(mem_sizes) == 3 + # Each entry should be a positive integer (size in bytes) + assert all(isinstance(size, int) and size > 0 for size in mem_sizes) + # The first two entries should be the same, since they have 3 sub-rows each + assert mem_sizes[0] == mem_sizes[1] + # The last entry should be the larger than the other two, since it has 4 sub-rows + assert mem_sizes[2] > mem_sizes[0] diff --git a/tests/hats_import/catalog/test_resume_plan.py b/tests/hats_import/catalog/test_resume_plan.py index ca8f0535..dc5deb2e 100644 --- a/tests/hats_import/catalog/test_resume_plan.py +++ b/tests/hats_import/catalog/test_resume_plan.py @@ -1,5 +1,7 @@ """Test catalog resume logic""" +from unittest.mock import MagicMock + import numpy as np import numpy.testing as npt import pytest @@ -258,3 +260,68 @@ def test_run_stages(tmp_path): assert not plan.should_run_splitting assert plan.should_run_reducing assert plan.should_run_finishing + + +def test_resume_plan_with_byte_pixel_threshold(tmp_path): + """Test ResumePlan initialization with byte_pixel_threshold specified.""" + # Mock import_args with necessary attributes + import_args = MagicMock() + import_args.resume_kwargs_dict.return_value = {"tmp_path": tmp_path} + import_args.debug_stats_only = False + import_args.byte_pixel_threshold = 100 # Simulate a specified byte_pixel_threshold + import_args.input_paths = ["file1", "file2"] + + # Initialize ResumePlan with the mocked import_args + resume_plan = ResumePlan(import_args=import_args) + + # Assert that threshold_mode is set to "mem_size" + assert resume_plan.threshold_mode == "mem_size" + + +def test_gather_plan_mem_size_mode(tmp_path): + """Test gather_plan in mem_size mode.""" + # Mock import_args with necessary attributes + import_args = MagicMock() + import_args.resume_kwargs_dict.return_value = {"tmp_path": tmp_path} + import_args.debug_stats_only = False + import_args.byte_pixel_threshold = 100 # Simulate a specified byte_pixel_threshold + import_args.input_paths = ["file1", "file2"] + + # Initialize ResumePlan with the mocked import_args + plan = ResumePlan(import_args=import_args, tmp_path=tmp_path) + + # Mock methods to avoid actual file operations + plan.get_remaining_map_keys = MagicMock(return_value=[("map_1", "file1"), ("map_2", "file2")]) + plan.done_file_exists = MagicMock(return_value=False) + plan.check_original_input_paths = MagicMock(return_value=import_args.input_paths) + plan.print_progress = MagicMock() + + # Call gather_plan and verify mem_size directory creation + plan.gather_plan() + plan.get_remaining_map_keys.assert_any_call(which_histogram="mem_size") + + +def test_get_remaining_map_keys_mem_size_and_invalid(tmp_path): + """Test get_remaining_map_keys with mem_size and invalid options.""" + # Create a ResumePlan instance with mem_size threshold_mode + plan = ResumePlan(tmp_path=tmp_path, progress_bar=False, input_paths=["file1", "file2"]) + plan.threshold_mode = "mem_size" + + # Mock the directory structure to simulate mem_size histogram files + mem_size_dir = tmp_path / plan.MEM_SIZE_HISTOGRAMS_DIR + mem_size_dir.mkdir() + (mem_size_dir / "map_0.npz").touch() + + # Call get_remaining_map_keys with mem_size + remaining_keys = plan.get_remaining_map_keys(which_histogram="mem_size") + assert len(remaining_keys) == 1 + assert remaining_keys[0] == ("map_1", "file2") + + # Raise error if threshold_mode is not mem_size + plan.threshold_mode = "row_count" + with pytest.raises(ValueError, match="when threshold_mode is not 'mem_size'"): + plan.get_remaining_map_keys(which_histogram="mem_size") + + # Call get_remaining_map_keys with an invalid option + with pytest.raises(ValueError, match="Unrecognized which_histogram value"): + plan.get_remaining_map_keys(which_histogram="invalid_option") diff --git a/tests/hats_import/catalog/test_run_import.py b/tests/hats_import/catalog/test_run_import.py index 996b52a9..b18f9d16 100644 --- a/tests/hats_import/catalog/test_run_import.py +++ b/tests/hats_import/catalog/test_run_import.py @@ -424,3 +424,32 @@ def test_import_with_npix_dir(dask_client, small_sky_parts_dir, tmp_path, assert # The file exists and contains the expected object IDs output_file = pix_dir / "0.parquet" assert_parquet_file_ids(output_file, "id", expected_ids) + + +@pytest.mark.dask +def test_mem_size_thresholding( + small_sky_parts_dir, + tmp_path, + dask_client, +): + """Test that we can run with mem_size thresholding.""" + args = ImportArguments( + output_artifact_name="small_sky_mem_size", + input_path=small_sky_parts_dir, + file_reader="csv", + output_path=tmp_path, + dask_tmp=tmp_path, + tmp_dir=tmp_path, + highest_healpix_order=1, + progress_bar=False, + pixel_threshold=10_000, + debug_stats_only=True, + run_stages=["mapping"], + ) + + runner.run(args, dask_client) + + # Check that the catalog metadata file exists + catalog = read_hats(args.catalog_path) + assert catalog.on_disk + assert catalog.catalog_path == args.catalog_path