Skip to content

Commit 5ce3d93

Browse files
authored
Merge pull request #629 from lincc-frameworks/lsst-band-auto-detect-fix
Fix _get_available_bands_from_manifest to find complete band entries
2 parents 37b3832 + dd0a4df commit 5ce3d93

File tree

5 files changed

+100
-16
lines changed

5 files changed

+100
-16
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ dev = [
8585
"sphinx-togglebutton",
8686
"sphinx-rtd-theme",
8787
"lsdb", # Used to test lsst dataset classes
88+
"cdshealpix <= 0.7.1",
8889
]
8990

9091
[build-system]

src/hyrax/data_sets/downloaded_lsst_dataset.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -371,17 +371,33 @@ def _longest_object_id_idx(self):
371371
return np.argmax([len(str(id)) for id in object_ids])
372372

373373
def _get_available_bands_from_manifest(self, manifest):
374-
"""Best effort to get available bands by looking at first 10 successful downloads for consistency."""
374+
"""Get available bands by finding entries with complete band coverage.
375+
376+
Uses cutout_shape[0] to determine the expected number of bands, then finds
377+
entries where downloaded_bands has that many entries (i.e., complete downloads).
378+
"""
375379
if len(manifest) == 0:
376380
return None, None
377381

378-
successful_entries = []
382+
# First, find the expected number of bands from cutout_shape
383+
# Look for the first entry with a valid cutout_shape
384+
expected_band_count = None
385+
for i in range(min(len(manifest), 1000)):
386+
shape = manifest["cutout_shape"][i]
387+
if shape is not None and len(shape) > 0 and shape[0] > 0:
388+
expected_band_count = shape[0]
389+
break
390+
391+
if expected_band_count is None:
392+
# No valid cutout_shape found
393+
return None, None
379394

380-
# Attempt to find first 10 successful downloads.
381-
# For long manifests (e.g. 1 million undownloaded cutouts), avoid iterating too far to find these 10.
395+
# Now find first 5 entries where downloaded_bands has the expected count
396+
complete_entries = []
382397
give_up_idx = min(len(manifest), 1000)
398+
383399
for i in range(give_up_idx):
384-
if len(successful_entries) >= 10:
400+
if len(complete_entries) >= 5:
385401
break
386402

387403
filename = manifest["filename"][i]
@@ -395,19 +411,26 @@ def _get_available_bands_from_manifest(self, manifest):
395411
and str(downloaded_bands_str).strip()
396412
):
397413
bands = [b.strip() for b in str(downloaded_bands_str).split(",") if b.strip()]
398-
if bands: # Non-empty band list
399-
successful_entries.append(bands)
400-
401-
if not successful_entries:
402-
return None, None
414+
# Only include entries with complete band coverage
415+
if len(bands) == expected_band_count:
416+
complete_entries.append(bands)
417+
418+
if not complete_entries:
419+
raise RuntimeError(
420+
f"We checked the first 1000 manifest entries and found no entries with complete band"
421+
f"coverage. Expected {expected_band_count} bands based on cutout_shape, but less than 5"
422+
f"downloaded entries have all bands present. Cannot automatically determine consistent"
423+
f"band structure."
424+
)
403425

404-
# Check that all successful entries have identical band lists
405-
first_bands = successful_entries[0]
406-
for i, bands in enumerate(successful_entries[1:], 1):
426+
# Check that all complete entries have identical band lists
427+
first_bands = complete_entries[0]
428+
for i, bands in enumerate(complete_entries[1:], 1):
407429
if bands != first_bands:
408430
raise RuntimeError(
409-
f"Inconsistent band ordering in manifest. Entry 0 has {first_bands}, "
410-
f"but entry {i} has {bands}. Cannot determine consistent band structure."
431+
f"Inconsistent band ordering in manifest among complete downloads. "
432+
f"Entry 0 has {first_bands}, but entry {i} has {bands}. "
433+
f"Cannot determine consistent band structure."
411434
)
412435

413436
return set(first_bands), first_bands

tests/hyrax/mocks/lsst_butler_mocks.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,9 +403,17 @@ class MockButler:
403403
band_fail_prob = {}
404404
fail_after_n = 0
405405
band_fail_after_n = {}
406+
band_fail_before_n = {}
406407

407408
@classmethod
408-
def reset(cls, fail_prob=0.0, band_fail_prob=None, fail_after_n=0, band_fail_after_n=None):
409+
def reset(
410+
cls,
411+
fail_prob=0.0,
412+
band_fail_prob=None,
413+
fail_after_n=0,
414+
band_fail_after_n=None,
415+
band_fail_before_n=None,
416+
):
409417
"""Resets the mock butler for a new test, and configures failure behavior
410418
411419
Parameters
@@ -423,12 +431,17 @@ def reset(cls, fail_prob=0.0, band_fail_prob=None, fail_after_n=0, band_fail_aft
423431
Continually fail particular band(s) after the provided number of calls to butler.get in the
424432
particular band. Dictionary provided has bands as keys and counts as values.
425433
Counts of zero mean no failures for that band
434+
band_fail_before_n : dict, optional
435+
Fail particular band(s) for the first N calls, then succeed. Dictionary provided has bands
436+
as keys and counts as values. For example band_fail_before_n={"g": 5} would cause the
437+
first 5 gets to g band to fail, then succeed afterwards.
426438
"""
427439
cls.initialized_thread_ids = []
428440
cls.fail_prob = fail_prob
429441
cls.band_fail_prob = {} if band_fail_prob is None else band_fail_prob
430442
cls.fail_after_n = fail_after_n
431443
cls.band_fail_after_n = {} if band_fail_after_n is None else band_fail_after_n
444+
cls.band_fail_before_n = {} if band_fail_before_n is None else band_fail_before_n
432445

433446
def __init__(self, repo=None, collections=None):
434447
"""Initialize mock butler.
@@ -441,6 +454,7 @@ def __init__(self, repo=None, collections=None):
441454
self._collections = collections
442455
self.request_count = 0
443456
self.band_request_count = {}
457+
self.band_attempt_count = {}
444458

445459
# Ensure only one Mock Butler per thread
446460
thread_id = threading.current_thread().ident
@@ -456,6 +470,12 @@ def __init__(self, repo=None, collections=None):
456470
self._data = {}
457471

458472
def _generate_errors(self, rng, band):
473+
# Track attempts (before any failures) for band_fail_before_n
474+
if self.band_attempt_count.get(band) is None:
475+
self.band_attempt_count[band] = 1
476+
else:
477+
self.band_attempt_count[band] += 1
478+
459479
if MockButler.fail_after_n != 0 and self.request_count >= MockButler.fail_after_n:
460480
msg = f"MockButler: Simulated fail after {self.request_count} requests."
461481
raise RuntimeError(msg)
@@ -469,6 +489,11 @@ def _generate_errors(self, rng, band):
469489
msg = f"MockButler: Simulated fail after {band_limit} requests to {band} band."
470490
raise RuntimeError(msg)
471491

492+
band_fail_before = MockButler.band_fail_before_n.get(band, 0)
493+
if band_fail_before != 0 and self.band_attempt_count.get(band, 0) <= band_fail_before:
494+
msg = f"MockButler: Simulated fail for first {band_fail_before} requests to {band} band."
495+
raise RuntimeError(msg)
496+
472497
band_fail_prob = MockButler.band_fail_prob.get(band, 0.0)
473498
if rng.random() > 1.0 - band_fail_prob:
474499
msg = f"MockButler: Simulated fail due to band failure probability {band} = {band_fail_prob}"

tests/hyrax/test_downloaded_lsst_dataset.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import mocks
1010
import pytest
1111
import torch
12+
import torchvision # noqa: F401 # Import before mock contexts to prevent kernel re-registration
1213
from mocks import lsst_config, mock_lsst_environment, sample_catalog, sample_catalog_saved # noqa: F401
1314

1415
from hyrax.data_sets.downloaded_lsst_dataset import DownloadedLSSTDataset
@@ -427,6 +428,39 @@ def test_failed_band_download(mock_lsst_environment, lsst_config, tmp_path): #
427428
assert torch.all(cutout[2] == cutout[2])
428429

429430

431+
def test_band_detection_with_partial_downloads(mock_lsst_environment, lsst_config, tmp_path): # noqa: F811
432+
"""
433+
Test that _get_available_bands_from_manifest correctly identifies bands
434+
from complete downloads, ignoring partial downloads that may appear earlier
435+
in the manifest.
436+
"""
437+
# Configure 4 bands
438+
lsst_config["data_set"]["filters"] = ["g", "r", "i", "z"]
439+
440+
with mock_lsst_environment():
441+
# Make g and r bands fail for the FIRST 5 downloads each, then succeed
442+
# Early entries will have only i,z (partial), later entries will have all 4
443+
dataset = DownloadedLSSTDatasetMocked(
444+
lsst_config,
445+
data_location=str(tmp_path),
446+
patcher=mock_lsst_environment,
447+
patcher_kwargs={"band_fail_before_n": {"g": 5, "r": 5}},
448+
)
449+
_manifest = dataset.download_cutouts()
450+
451+
# Request only g,r,i bands - triggers _get_available_bands_from_manifest
452+
# which must find complete 4-band entries to determine available bands
453+
lsst_config["data_set"]["filters"] = ["g", "r", "i"]
454+
455+
dataset = DownloadedLSSTDatasetMocked(
456+
lsst_config, data_location=str(tmp_path), patcher=mock_lsst_environment
457+
)
458+
459+
# Verify band filtering found complete entries and set up correctly
460+
assert dataset._is_filtering_bands is True
461+
assert set(dataset.BANDS) == {"g", "r", "i"}
462+
463+
430464
def test_catalog_ordering(mock_lsst_environment, lsst_config, tmp_path, sample_catalog): # noqa: F811
431465
"""
432466
Test that after a download the ordering of a new dataset object is given in the same order

tests/hyrax/test_lsst_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import mocks
1212
import torch
13+
import torchvision # noqa: F401 # Import before mock contexts to prevent kernel re-registration
1314
from mocks import lsst_config, mock_lsst_environment, sample_catalog, sample_catalog_saved # noqa: F401
1415

1516

0 commit comments

Comments
 (0)