Skip to content

Commit 4332871

Browse files
claude refactoring of masking in metacal
1 parent 9a39d94 commit 4332871

File tree

1 file changed

+133
-18
lines changed

1 file changed

+133
-18
lines changed

src/sp_validation/basic.py

Lines changed: 133 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -359,27 +359,23 @@ def _masking_gal(self):
359359
[self.ns, self.m1, self.p1, self.m2, self.p2],
360360
['ns', 'm1', 'p1', 'm2', 'p2']
361361
):
362-
Tr_tmp = data['T']
363-
if self._size_corr_ell:
364-
Tr_tmp *= (
365-
(1 - (data['g1'] ** 2 + data['g2'] ** 2))
366-
/ (1 + (data['g1'] ** 2 + data['g2'] ** 2))
367-
)
362+
# Add SNR if available from sextractor
368363
if hasattr(self, 'snr_sextractor'):
369-
snr_flux = self.snr_sextractor
370-
else:
371-
snr_flux = data['flux'] / data['flux_err']
372-
373-
mask_tmp = (
374-
(data['flag'] == 0)
375-
& (Tr_tmp / data['Tpsf'] > self._rel_size_min)
376-
& (Tr_tmp / data['Tpsf'] < self._rel_size_max)
377-
& (snr_flux > self._snr_min)
378-
& (snr_flux < self._snr_max)
364+
data['snr_flux'] = self.snr_sextractor
365+
366+
# Get masks using standalone function
367+
# Data dict already has the required keys: g1, g2, T, Tpsf, flag, flux, flux_err
368+
masks = get_metacal_masks_separate(
369+
data,
370+
snr_min=self._snr_min,
371+
snr_max=self._snr_max,
372+
rel_size_min=self._rel_size_min,
373+
rel_size_max=self._rel_size_max,
374+
size_corr_ell=self._size_corr_ell,
379375
)
380376

381-
# Take care of rotated version
382-
ind_masked = np.where(mask_tmp == True)[0]
377+
# Extract indices from combined mask
378+
ind_masked = np.where(masks['combined'] == True)[0]
383379

384380
self.mask_dict[name] = ind_masked
385381

@@ -389,6 +385,7 @@ def _masking_gal_mom(self):
389385
...
390386
391387
"""
388+
print("Masking of galaxies with moments measurements is deprecated")
392389
self.mask_dict = {}
393390
for data, name in zip(
394391
[self.ns, self.m1, self.p1, self.m2, self.p2],
@@ -562,6 +559,124 @@ def _return():
562559
)
563560

564561

562+
def get_metacal_masks_separate(
563+
data,
564+
prefix='NGMIX',
565+
snr_min=10,
566+
snr_max=500,
567+
rel_size_min=0.5,
568+
rel_size_max=3.0,
569+
size_corr_ell=True,
570+
shear_variant='NOSHEAR',
571+
col_2d=True,
572+
):
573+
"""Get Metacal Masks Separate.
574+
575+
Get separate masks for metacalibration quality cuts without running
576+
full metacal computation.
577+
578+
Parameters
579+
----------
580+
data : array or Table or dict
581+
input galaxy catalogue, or dictionary with keys 'T', 'Tpsf',
582+
'flag', and either 'flux'/'flux_err' or 'snr_flux'.
583+
If size_corr_ell=True, also requires 'g1' and 'g2'.
584+
prefix : str, optional, default='NGMIX'
585+
column name prefix in catalogue (ignored if data is dict)
586+
snr_min : float, optional, default=10
587+
signal-to-noise minimum
588+
snr_max : float, optional, default=500
589+
signal-to-noise maximum
590+
rel_size_min : float, optional, default=0.5
591+
relative size minimum
592+
rel_size_max : float, optional, default=3.0
593+
relative size maximum
594+
size_corr_ell : bool, optional, default=True
595+
if True, correct size for ellipticity
596+
shear_variant : str, optional, default='NOSHEAR'
597+
shear variant name, e.g. 'NOSHEAR', '1M', '1P', '2M', '2P'
598+
(ignored if data is dict)
599+
col_2d : bool, optional, default=True
600+
if True, ellipticity in one 2D column; if False, in two columns
601+
(ignored if data is dict)
602+
603+
Returns
604+
-------
605+
dict
606+
Dictionary containing separate masks:
607+
- 'flag': boolean mask for flag == 0
608+
- 'rel_size': boolean mask for relative size cuts
609+
- 'snr': boolean mask for SNR cuts
610+
- 'combined': boolean mask for all cuts combined
611+
612+
"""
613+
# Handle dictionary input (from metacal class)
614+
if isinstance(data, dict):
615+
T = data['T']
616+
Tpsf = data['Tpsf']
617+
flag = data['flag']
618+
if 'snr_flux' in data:
619+
snr_flux = data['snr_flux']
620+
else:
621+
snr_flux = data['flux'] / data['flux_err']
622+
623+
# Only extract g1, g2 if needed for size correction
624+
if size_corr_ell:
625+
g1 = data['g1']
626+
g2 = data['g2']
627+
else:
628+
# Handle catalogue input (raw data)
629+
T = data[f'{prefix}_T_{shear_variant}']
630+
Tpsf = data[f'{prefix}_Tpsf_{shear_variant}']
631+
flag = data[f'{prefix}_FLAGS_{shear_variant}']
632+
633+
# Get SNR
634+
if f'SNR_WIN' in data.dtype.names or (hasattr(data, 'colnames') and 'SNR_WIN' in data.colnames):
635+
snr_flux = data['SNR_WIN']
636+
else:
637+
flux = data[f'{prefix}_FLUX_{shear_variant}']
638+
flux_err = data[f'{prefix}_FLUX_ERR_{shear_variant}']
639+
snr_flux = flux / flux_err
640+
641+
# Only extract g1, g2 if needed for size correction
642+
if size_corr_ell:
643+
if col_2d:
644+
g1 = data[f'{prefix}_ELL_{shear_variant}'][:, 0]
645+
g2 = data[f'{prefix}_ELL_{shear_variant}'][:, 1]
646+
else:
647+
g1 = data[f'{prefix}_ELL_{shear_variant}_0']
648+
g2 = data[f'{prefix}_ELL_{shear_variant}_1']
649+
650+
# Apply size correction for ellipticity if requested
651+
Tr_tmp = T.copy() if hasattr(T, 'copy') else np.array(T)
652+
if size_corr_ell:
653+
Tr_tmp *= (
654+
(1 - (g1 ** 2 + g2 ** 2))
655+
/ (1 + (g1 ** 2 + g2 ** 2))
656+
)
657+
658+
# Create separate masks
659+
mask_flag = (flag == 0)
660+
mask_rel_size = (
661+
(Tr_tmp / Tpsf > rel_size_min)
662+
& (Tr_tmp / Tpsf < rel_size_max)
663+
)
664+
mask_snr = (
665+
(snr_flux > snr_min)
666+
& (snr_flux < snr_max)
667+
)
668+
669+
# Combined mask
670+
mask_combined = mask_flag & mask_rel_size & mask_snr
671+
672+
return {
673+
'flag': mask_flag,
674+
'rel_size': mask_rel_size,
675+
'snr': mask_snr,
676+
'combined': mask_combined,
677+
}
678+
679+
565680
def jackknif_weighted_average2(
566681
data,
567682
weights,

0 commit comments

Comments
 (0)