Skip to content

Commit d8f75af

Browse files
committed
* fix tests
* update sky loss experiment * fix ionosphere syntax error
1 parent ccbbdeb commit d8f75af

File tree

4 files changed

+255
-93
lines changed

4 files changed

+255
-93
lines changed

dsa2000_cal/notebooks/explore_sky_loss.ipynb

Lines changed: 196 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
"import glob\n",
4646
"import os\n",
4747
"import time\n",
48+
"from typing import NamedTuple, Dict\n",
4849
"\n",
4950
"from tqdm import tqdm\n",
5051
"\n",
@@ -68,7 +69,7 @@
6869
"from dsa2000_assets.registries import source_model_registry, array_registry, misc_registry\n",
6970
"from dsa2000_common.common.array_types import FloatArray\n",
7071
"from dsa2000_common.common.astropy_utils import create_spherical_spiral_grid, get_time_of_local_meridean\n",
71-
"from dsa2000_common.common.noise import calc_image_noise\n",
72+
"from dsa2000_common.common.noise import calc_image_noise, calc_baseline_noise\n",
7273
"from dsa2000_common.common.quantity_utils import time_to_jnp, quantity_to_jnp\n",
7374
"from dsa2000_common.common.serialise_utils import SerialisableBaseModel\n",
7475
"from dsa2000_common.delay_models.base_far_field_delay_engine import build_far_field_delay_engine, \\\n",
@@ -86,13 +87,45 @@
8687
"import pylab as plt\n",
8788
"\n",
8889
"\n",
89-
"@partial(jax.jit)\n",
90-
"def compute_sky_values(l, m, n,\n",
91-
" bright_sky_model: BasePointSourceModel, total_gain_model: GainModel,\n",
92-
" times: FloatArray, far_field_delay_engine: BaseFarFieldDelayEngine,\n",
93-
" geodesic_model: BaseGeodesicModel, freqs: FloatArray,\n",
94-
" zero_point: FloatArray, integration_time: FloatArray, channel_width: FloatArray,\n",
95-
" ra0, dec0):\n",
90+
"class ValuesAndRMS(NamedTuple):\n",
91+
" rms_no_noise: FloatArray\n",
92+
" max_no_noise: FloatArray\n",
93+
" min_no_noise: FloatArray\n",
94+
" mean_no_noise: FloatArray\n",
95+
" std_no_noise: FloatArray\n",
96+
" rms_noise: FloatArray\n",
97+
" max_noise: FloatArray\n",
98+
" min_noise: FloatArray\n",
99+
" mean_noise: FloatArray\n",
100+
" std_noise: FloatArray\n",
101+
"\n",
102+
" mean_abs_R: FloatArray\n",
103+
" std_abs_R: FloatArray\n",
104+
" mean_time_smear: FloatArray\n",
105+
" std_time_smear: FloatArray\n",
106+
" mean_freq_smear: FloatArray\n",
107+
" std_freq_smear: FloatArray\n",
108+
" mean_smear: FloatArray\n",
109+
" std_smear: FloatArray\n",
110+
"\n",
111+
"\n",
112+
"@partial(jax.jit, static_argnames=['smearing'])\n",
113+
"def compute_rms_and_values(\n",
114+
" key,\n",
115+
" l: FloatArray, m: FloatArray, n: FloatArray,\n",
116+
" bright_sky_model: BasePointSourceModel,\n",
117+
" total_gain_model: GainModel,\n",
118+
" times: FloatArray,\n",
119+
" far_field_delay_engine: BaseFarFieldDelayEngine,\n",
120+
" geodesic_model: BaseGeodesicModel,\n",
121+
" freqs: FloatArray,\n",
122+
" zero_point: FloatArray,\n",
123+
" integration_time: FloatArray,\n",
124+
" channel_width: FloatArray,\n",
125+
" ra0: FloatArray, dec0: FloatArray,\n",
126+
" baseline_noise: FloatArray,\n",
127+
" smearing: bool\n",
128+
"):\n",
96129
" \"\"\"\n",
97130
" Compute the RMS in the field due to uncalibrated extra-field sources.\n",
98131
"\n",
@@ -137,38 +170,116 @@
137170
" A = jnp.where(elevation[0, 0, :] > 0, jnp.mean(bright_sky_model.A, axis=1), 0.) # [N]\n",
138171
" c = quantity_to_jnp(const.c)\n",
139172
"\n",
140-
" def add_square_sum(image_sum, x):\n",
173+
" phase_eval = (jnp.sum(lmn_eval[:, None, ...] * uvw[None, ...], axis=-1) - w[None, :]) # [M, B]\n",
174+
" phase = -(jnp.sum(lmn_sources[None, ...] * uvw[:, None, :], axis=-1) - w[:, None]) # [B, N]\n",
175+
" phase_dt = -(jnp.sum(lmn_sources[None, ...] * uvw_dt[:, None, :], axis=-1) - w_dt[:, None]) # [B, N]\n",
176+
" R = (phase_dt - phase) / (0.5 * integration_time) # [B, N]\n",
177+
"\n",
178+
" mean_abs_R = jnp.mean(jnp.abs(R))\n",
179+
" std_abs_R = jnp.std(jnp.abs(R))\n",
180+
"\n",
181+
" def sinc(x):\n",
182+
" return jax.lax.div(jnp.sin(x), x)\n",
183+
"\n",
184+
" def accumulate_over_freq(accumulate, x):\n",
141185
" # scalar -> [B] -> [B, M] -> scalar\n",
142-
" (freq,) = x # scalar\n",
186+
" (freq, key) = x # scalar\n",
143187
"\n",
144188
" gains = total_gain_model.compute_gain(freq[None], times, lmn_geodesic) # [T, A, 1, N]\n",
145189
" g1 = gains[0, visibilty_coords.antenna1, 0, :] # [B, N]\n",
146190
" g2 = gains[0, visibilty_coords.antenna2, 0, :] # [B, N]\n",
147191
" wavelength = c / freq\n",
148192
" coeff = 2 * jnp.pi / wavelength\n",
149-
" # vis\n",
150-
" phase = -(jnp.sum(lmn_sources[None, ...] * uvw[:, None, :], axis=-1) - w[:, None]) # [B, N]\n",
151-
" phase_dt = -(jnp.sum(lmn_sources[None, ...] * uvw_dt[:, None, :], axis=-1) - w_dt[:, None]) # [B, N]\n",
152-
" R = (phase_dt - phase) / (0.5 * integration_time)\n",
153-
" time_smear_modulation = jnp.sinc(R * integration_time / wavelength) # [B, N]\n",
154-
" freq_smear_moduluation = jnp.sinc(phase * channel_width / c) # [B, N]\n",
155-
" smear_modulation = time_smear_modulation * freq_smear_moduluation # [B, N]\n",
156-
" fringe = jax.lax.complex(jnp.cos(coeff * phase), jnp.sin(coeff * phase)).astype(jnp.complex64)\n",
157-
" vis = jnp.sum((g1 * g2.conj()) * (A * (fringe * freq_smear_moduluation)), axis=1).astype(jnp.float32) # [B]\n",
158-
" # image\n",
159-
" phase = (jnp.sum(lmn_eval[:, None, ...] * uvw[None, ...], axis=-1) - w[None, :]) # [M, B]\n",
160-
" fringe = jax.lax.complex(jnp.cos(coeff * phase), jnp.sin(coeff * phase)).astype(jnp.complex64)\n",
161-
" n = lmn_eval[..., 2].astype(jnp.float32) # [M]\n",
162-
" image = n * jnp.sum((vis * fringe).real, axis=1) # [M]\n",
193+
" # compute vis with optional smearing\n",
194+
" fringe = jax.lax.complex(jnp.cos(coeff * phase), jnp.sin(coeff * phase)).astype(jnp.complex64) # [B, N]\n",
195+
" if smearing:\n",
196+
" time_smear_modulation = sinc(R * (jnp.pi * integration_time / wavelength)) # [B, N]\n",
197+
" freq_smear_moduluation = sinc(phase * (jnp.pi * channel_width / c)) # [B, N]\n",
198+
" smear_modulation = time_smear_modulation * freq_smear_moduluation # [B, N]\n",
199+
" mean_time_smear = jnp.mean(time_smear_modulation)\n",
200+
" std_time_smear = jnp.std(time_smear_modulation)\n",
201+
" mean_freq_smear = jnp.mean(freq_smear_moduluation)\n",
202+
" std_freq_smear = jnp.std(freq_smear_moduluation)\n",
203+
" mean_smear = jnp.mean(smear_modulation)\n",
204+
" std_smear = jnp.std(smear_modulation)\n",
205+
" vis = jnp.sum((g1 * g2.conj()) * (A * (fringe * smear_modulation)), axis=1).astype(jnp.complex64) # [B]\n",
206+
" else:\n",
207+
" mean_time_smear = std_time_smear = mean_freq_smear = std_freq_smear = mean_smear = std_smear = jnp.asarray(\n",
208+
" 1., jnp.float32)\n",
209+
" vis = jnp.sum((g1 * g2.conj()) * (A * fringe), axis=1).astype(jnp.complex64) # [B]\n",
210+
" key1, key2 = jax.random.split(key)\n",
211+
" # divide by sqrt(2) for real and imag part\n",
212+
" noise = (baseline_noise / np.sqrt(2)) * jax.lax.complex(\n",
213+
" jax.random.normal(key1, shape=vis.shape, dtype=vis.real.dtype),\n",
214+
" jax.random.normal(key2, shape=vis.shape, dtype=vis.imag.dtype)).astype(\n",
215+
" vis.dtype)\n",
216+
" vis_noise = vis + noise\n",
217+
" # compute image, normalising\n",
218+
" fringe = jax.lax.complex(jnp.cos(coeff * phase_eval), jnp.sin(coeff * phase_eval)).astype(jnp.complex64)\n",
219+
" image = n.astype(jnp.float32) * jnp.sum((vis * fringe).real, axis=1) # [M]\n",
220+
" image_noise = n.astype(jnp.float32) * jnp.sum((vis_noise * fringe).real, axis=1) # [M]\n",
163221
" image /= vis.size\n",
164-
" image -= zero_point\n",
165-
" image_sum += image\n",
166-
" return image_sum.astype(jnp.float32), None\n",
222+
" image_noise /= vis.size\n",
223+
" delta = (\n",
224+
" image, image_noise,\n",
225+
" mean_time_smear, std_time_smear, mean_freq_smear, std_freq_smear, mean_smear, std_smear\n",
226+
" )\n",
227+
" accumulate = jax.tree.map(lambda x, y: jax.lax.add(x.astype(jnp.float32), y.astype(jnp.float32)), accumulate,\n",
228+
" delta)\n",
229+
" return accumulate, None\n",
230+
"\n",
231+
" accumulate = (\n",
232+
" jnp.zeros(l.shape, jnp.float32), jnp.zeros(l.shape, jnp.float32),\n",
233+
" jnp.zeros(1, jnp.float32), jnp.zeros(1, jnp.float32),\n",
234+
" jnp.zeros(1, jnp.float32), jnp.zeros(1, jnp.float32),\n",
235+
" jnp.zeros(1, jnp.float32), jnp.zeros(1, jnp.float32)\n",
236+
" )\n",
237+
" accuulate, _ = jax.lax.scan(\n",
238+
" accumulate_over_freq,\n",
239+
" accumulate,\n",
240+
" (freqs, jax.random.split(key, len(freqs)))\n",
241+
" )\n",
167242
"\n",
168-
" image_sum, _ = jax.lax.scan(add_square_sum, jnp.zeros(l.shape, jnp.float32), (freqs,))\n",
169-
" image_sum /= freqs.size\n",
170-
" rms = jnp.sqrt(jnp.sum(image_sum ** 2))\n",
171-
" return rms\n",
243+
" # Normalize by number of freqs\n",
244+
" accumulate = jax.tree.map(lambda x: x / len(freqs), accumulate)\n",
245+
" (\n",
246+
" image, image_noise,\n",
247+
" mean_time_smear, std_time_smear, mean_freq_smear, std_freq_smear, mean_smear, std_smear\n",
248+
" ) = accumulate\n",
249+
"\n",
250+
" # Compute RMS and image normal stats\n",
251+
" rms_no_noise = jnp.sqrt(jnp.sum((image - zero_point) ** 2))\n",
252+
" max_no_noise = jnp.max(image)\n",
253+
" min_no_noise = jnp.min(image)\n",
254+
" mean_no_noise = jnp.mean(image)\n",
255+
" std_no_noise = jnp.std(image)\n",
256+
"\n",
257+
" rms_noise = jnp.sqrt(jnp.sum((image_noise - zero_point) ** 2))\n",
258+
" max_noise = jnp.max(image_noise)\n",
259+
" min_noise = jnp.min(image_noise)\n",
260+
" mean_noise = jnp.mean(image_noise)\n",
261+
" std_noise = jnp.std(image_noise)\n",
262+
"\n",
263+
" return ValuesAndRMS(\n",
264+
" rms_no_noise=rms_no_noise,\n",
265+
" max_no_noise=max_no_noise,\n",
266+
" min_no_noise=min_no_noise,\n",
267+
" mean_no_noise=mean_no_noise,\n",
268+
" std_no_noise=std_no_noise,\n",
269+
" rms_noise=rms_noise,\n",
270+
" max_noise=max_noise,\n",
271+
" min_noise=min_noise,\n",
272+
" mean_noise=mean_noise,\n",
273+
" std_noise=std_noise,\n",
274+
" mean_abs_R=mean_abs_R,\n",
275+
" std_abs_R=std_abs_R,\n",
276+
" mean_time_smear=mean_time_smear,\n",
277+
" std_time_smear=std_time_smear,\n",
278+
" mean_freq_smear=mean_freq_smear,\n",
279+
" std_freq_smear=std_freq_smear,\n",
280+
" mean_smear=mean_smear,\n",
281+
" std_smear=std_smear\n",
282+
" )\n",
172283
"\n",
173284
"\n",
174285
"@jax.jit\n",
@@ -222,7 +333,8 @@
222333
" dawn: bool,\n",
223334
" high_sun_spot: bool,\n",
224335
" with_ionosphere: bool = False,\n",
225-
" with_dish_effects: bool = False\n",
336+
" with_dish_effects: bool = False,\n",
337+
" with_smearing: bool = True\n",
226338
"):\n",
227339
" plt.close('all')\n",
228340
" t0 = time.time()\n",
@@ -245,15 +357,21 @@
245357
" freqs_jax = quantity_to_jnp(freqs)\n",
246358
" times_jax = time_to_jnp(times, ref_time)\n",
247359
"\n",
248-
" thermal_floor = np.asarray(calc_image_noise(\n",
360+
" thermal_noise = float(calc_image_noise(\n",
249361
" system_equivalent_flux_density=quantity_to_jnp(array.get_system_equivalent_flux_density(), 'Jy'),\n",
250362
" bandwidth_hz=quantity_to_jnp(array.get_channel_width()) * len(freqs),\n",
251-
" t_int_s=10.3 * 60.,\n",
363+
" t_int_s=quantity_to_jnp(array.get_integration_time(), 's'),\n",
252364
" num_antennas=len(antennas),\n",
253365
" flag_frac=0.33,\n",
254366
" num_pol=2\n",
255367
" )) * au.Jy\n",
256368
"\n",
369+
" baseline_noise = float(calc_baseline_noise(\n",
370+
" system_equivalent_flux_density=quantity_to_jnp(array.get_system_equivalent_flux_density(), 'Jy'),\n",
371+
" chan_width_hz=quantity_to_jnp(array.get_channel_width(), 'Hz'),\n",
372+
" t_int_s=quantity_to_jnp(array.get_integration_time(), 's')\n",
373+
" ) / np.sqrt(2)) * au.Jy # assume stokes I so 2 cross pols combined reduces noise by sqrt(2)\n",
374+
"\n",
257375
" far_field_delay_engine = build_far_field_delay_engine(\n",
258376
" antennas=antennas,\n",
259377
" phase_center=phase_center,\n",
@@ -387,15 +505,15 @@
387505
" f.write(json.dumps(beam_amp.tolist()))\n",
388506
"\n",
389507
" # A * beam^2 * psf > sigma => A > 1muJy / beam^2 / psf\n",
390-
" global_flux_cut = thermal_floor / (global_crest_peak ** 2 * prior_psf_sidelobe_peak)\n",
391-
" print(f\"Thermal floor: {thermal_floor}\")\n",
508+
" global_flux_cut = thermal_noise / (global_crest_peak ** 2 * prior_psf_sidelobe_peak)\n",
509+
" print(f\"Thermal floor: {thermal_noise}\")\n",
392510
" print(f\"Global crest peak outside {angular_radius}: {global_crest_peak}\")\n",
393511
" print(f\"PSF sidelobe peak: {prior_psf_sidelobe_peak}\")\n",
394512
" print(f\"==> Flux cut: {global_flux_cut}\")\n",
395513
" select_mask = jnp.any(bright_sky_model.A > quantity_to_jnp(global_flux_cut, 'Jy'), axis=1) # [N]\n",
396514
" print(f\"Global: {np.sum(select_mask)} selected brightest sources out of {len(bright_sky_model.A)}\")\n",
397515
"\n",
398-
" flux_cut = thermal_floor / (beam_amp ** 2 * prior_psf_sidelobe_peak) # [N, F]\n",
516+
" flux_cut = thermal_noise / (beam_amp ** 2 * prior_psf_sidelobe_peak) # [N, F]\n",
399517
" select_mask = jnp.any(bright_sky_model.A > flux_cut, axis=1) # [N]\n",
400518
" print(f\"Mean beam amp: {jnp.mean(beam_amp)}\")\n",
401519
" print(f\"Mean flux cut: {jnp.mean(flux_cut[select_mask])}\")\n",
@@ -512,8 +630,10 @@
512630
" # DFT vis only M directions\n",
513631
" # Compute RMS, with zero-point adjustment +1/(N-1) or not\n",
514632
" zero_point = 0. #- 1 / (len(antennas) - 1)\n",
515-
" rms = jax.block_until_ready(\n",
516-
" compute_sky_values(\n",
633+
" key, sample_key = jax.random.split(key)\n",
634+
" values = jax.block_until_ready(\n",
635+
" compute_rms_and_values(\n",
636+
" key=sample_key,\n",
517637
" l=l, m=m, n=n,\n",
518638
" bright_sky_model=bright_sky_model,\n",
519639
" total_gain_model=total_gain_model,\n",
@@ -522,12 +642,15 @@
522642
" geodesic_model=geodesic_model,\n",
523643
" freqs=freqs_jax,\n",
524644
" zero_point=zero_point,\n",
525-
" integration_time=quantity_to_jnp(array.get_integration_time()),\n",
526-
" channel_width=quantity_to_jnp(array.get_channel_width()),\n",
645+
" integration_time=quantity_to_jnp(array.get_integration_time(), 's'),\n",
646+
" channel_width=quantity_to_jnp(array.get_channel_width(), 'Hz'),\n",
527647
" ra0=phase_center.ra.rad,\n",
528-
" dec0=phase_center.dec.rad\n",
648+
" dec0=phase_center.dec.rad,\n",
649+
" baseline_noise=quantity_to_jnp(baseline_noise, 'Jy'),\n",
650+
" smearing=with_smearing,\n",
529651
" )\n",
530652
" )\n",
653+
" result_values = jax.tree.map(float, values)\n",
531654
" t1 = time.time()\n",
532655
" result = Result(\n",
533656
" phase_center=phase_center,\n",
@@ -545,8 +668,11 @@
545668
" high_sun_spot=high_sun_spot,\n",
546669
" with_ionosphere=with_ionosphere,\n",
547670
" with_dish_effects=with_dish_effects,\n",
671+
" with_smearing=with_smearing,\n",
548672
" run_time=float(t1 - t0),\n",
549-
" rms=float(rms) * au.Jy\n",
673+
" thermal_noise=thermal_noise,\n",
674+
" baseline_noise=baseline_noise,\n",
675+
" result_values=result_values._asdict()\n",
550676
" )\n",
551677
"\n",
552678
" with open(os.path.join(save_folder, f'result_{result_num:03d}.json'), 'w') as f:\n",
@@ -569,31 +695,36 @@
569695
" high_sun_spot: bool\n",
570696
" with_ionosphere: bool\n",
571697
" with_dish_effects: bool\n",
698+
" with_smearing: bool\n",
572699
" run_time: float\n",
573-
" rms: au.Quantity\n",
700+
" thermal_noise: au.Quantity\n",
701+
" baseline_noise: au.Quantity\n",
702+
" result_values: Dict[str, float]\n",
574703
"\n",
575704
"\n",
576705
"for pointing_offset_stddev in [0, 1, 2, 4] * au.arcmin:\n",
577706
" for axial_focus_error_stddev in [0, 3, 5] * au.mm:\n",
578707
" for horizon_peak_astigmatism_stddev in [0, 1, 2, 4] * au.mm:\n",
579-
" main(\n",
580-
" seed=0,\n",
581-
" save_folder='sky_loss_11Mar2025_varying_systematics_smearing',\n",
582-
" array_name='dsa2000_optimal_v1',\n",
583-
" pointing=ac.ICRS(0 * au.deg, 0 * au.deg),\n",
584-
" num_measure_points=256,\n",
585-
" angular_radius=1.75 * au.deg,\n",
586-
" prior_psf_sidelobe_peak=1e-3,\n",
587-
" bright_source_id='nvss_calibrators',\n",
588-
" pointing_offset_stddev=pointing_offset_stddev,\n",
589-
" axial_focus_error_stddev=axial_focus_error_stddev,\n",
590-
" horizon_peak_astigmatism_stddev=horizon_peak_astigmatism_stddev,\n",
591-
" turbulent=True,\n",
592-
" dawn=True,\n",
593-
" high_sun_spot=True,\n",
594-
" with_ionosphere=True,\n",
595-
" with_dish_effects=True\n",
596-
" )\n",
708+
" for with_smearing in [True, False]:\n",
709+
" main(\n",
710+
" seed=0,\n",
711+
" save_folder='sky_loss_11Mar2025_varying_systematics_more_stats',\n",
712+
" array_name='dsa2000_optimal_v1',\n",
713+
" pointing=ac.ICRS(0 * au.deg, 0 * au.deg),\n",
714+
" num_measure_points=256,\n",
715+
" angular_radius=1.75 * au.deg,\n",
716+
" prior_psf_sidelobe_peak=1e-3,\n",
717+
" bright_source_id='nvss_calibrators',\n",
718+
" pointing_offset_stddev=pointing_offset_stddev,\n",
719+
" axial_focus_error_stddev=axial_focus_error_stddev,\n",
720+
" horizon_peak_astigmatism_stddev=horizon_peak_astigmatism_stddev,\n",
721+
" turbulent=True,\n",
722+
" dawn=True,\n",
723+
" high_sun_spot=True,\n",
724+
" with_ionosphere=True,\n",
725+
" with_dish_effects=True,\n",
726+
" with_smearing=with_smearing\n",
727+
" )\n",
597728
"\n",
598729
"fill_registries()\n",
599730
"survey_pointings = misc_registry.get_instance(misc_registry.get_match('survey_pointings'))\n",
@@ -602,7 +733,7 @@
602733
" print(pointing)\n",
603734
" main(\n",
604735
" seed=0,\n",
605-
" save_folder='sky_loss_11Mar2025_full_survey_smearing',\n",
736+
" save_folder='sky_loss_11Mar2025_full_survey_more_stats',\n",
606737
" array_name='dsa2000_optimal_v1',\n",
607738
" pointing=pointing,\n",
608739
" num_measure_points=256,\n",
@@ -616,7 +747,8 @@
616747
" dawn=True,\n",
617748
" high_sun_spot=True,\n",
618749
" with_ionosphere=True,\n",
619-
" with_dish_effects=True\n",
750+
" with_dish_effects=True,\n",
751+
" with_smearing=True\n",
620752
" )\n",
621753
"\n",
622754
"\n",

0 commit comments

Comments
 (0)