Skip to content

Commit 42e2db5

Browse files
committed
Merge branch 'develop'
2 parents 52c1121 + f5158d5 commit 42e2db5

30 files changed

+1925
-148
lines changed

dsa2000_cal/benchmarking/actual_calibration/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
build_gaussian_source_model, \
6060
BaseGaussianSourceModel
6161

62-
from dsa2000_fm.forward_models.streaming.average_utils import average_rule
62+
from dsa2000_fm.actors.average_utils import average_rule
6363

6464
tfpd = tfp.distributions
6565

dsa2000_cal/dashboards/dsa/pages/view_calibration_solutions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import ray
44
import streamlit as st
55

6-
from dsa2000_fm.forward_models.streaming.calibration_solution_cache import CalibrationSolutionCache, \
6+
from dsa2000_fm.actors.calibration_solution_cache import CalibrationSolutionCache, \
77
CalibrationSolution
88

99

dsa2000_cal/notebooks/explore_calibration.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"\n",
2222
"from dsa2000_common.common.noise import calc_baseline_noise\n",
2323
"from dsa2000_common.common.quantity_utils import time_to_jnp, quantity_to_jnp\n",
24-
"from dsa2000_fm.forward_models.streaming.calibrator import Calibration\n",
24+
"from dsa2000_fm.actors.calibrator import Calibration\n",
2525
"from dsa2000_fm.forward_models.utils import ObservationSetup\n",
2626
"from dsa2000_common.visibility_model.source_models.celestial.base_point_source_model import build_point_source_model\n",
2727
"\n",

dsa2000_cal/notebooks/explore_sky_loss.ipynb

Lines changed: 82 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@
4242
"source": [
4343
"\n",
4444
"\n",
45-
"import glob\n",
4645
"import os\n",
46+
"import queue\n",
4747
"import time\n",
4848
"from typing import NamedTuple, Dict\n",
4949
"\n",
@@ -318,6 +318,9 @@
318318
"\n",
319319
"\n",
320320
"def main(\n",
321+
" cpu,\n",
322+
" gpu,\n",
323+
" result_num: int,\n",
321324
" seed: int,\n",
322325
" save_folder: str,\n",
323326
" array_name: str,\n",
@@ -338,9 +341,6 @@
338341
"):\n",
339342
" plt.close('all')\n",
340343
" t0 = time.time()\n",
341-
" result_num = len(glob.glob(os.path.join(save_folder, 'result_*.json')))\n",
342-
" cpu = jax.devices(\"cpu\")[0]\n",
343-
" gpu = jax.devices(\"cuda\")[0] # or \"gpu\" depending on platform\n",
344344
"\n",
345345
" key = jax.random.PRNGKey(seed)\n",
346346
" fill_registries()\n",
@@ -702,36 +702,57 @@
702702
" result_values: Dict[str, float]\n",
703703
"\n",
704704
"\n",
705+
"gpu = jax.devices(\"cuda\")[0] # or \"gpu\" depending on platform\n",
706+
"\n",
707+
"\n",
708+
"def run_varying_systematics(result_idx, cpu, gpu, pointing_offset_stddev, axial_focus_error_stddev,\n",
709+
" horizon_peak_astigmatism_stddev, with_smearing):\n",
710+
" main(\n",
711+
" cpu=cpu,\n",
712+
" gpu=gpu,\n",
713+
" result_num=result_idx,\n",
714+
" seed=0,\n",
715+
" save_folder='sky_loss_11Mar2025_varying_systematics_more_stats',\n",
716+
" array_name='dsa2000_optimal_v1',\n",
717+
" pointing=ac.ICRS(0 * au.deg, 0 * au.deg),\n",
718+
" num_measure_points=256,\n",
719+
" angular_radius=1.75 * au.deg,\n",
720+
" prior_psf_sidelobe_peak=1e-3,\n",
721+
" bright_source_id='nvss_calibrators',\n",
722+
" pointing_offset_stddev=pointing_offset_stddev,\n",
723+
" axial_focus_error_stddev=axial_focus_error_stddev,\n",
724+
" horizon_peak_astigmatism_stddev=horizon_peak_astigmatism_stddev,\n",
725+
" turbulent=True,\n",
726+
" dawn=True,\n",
727+
" high_sun_spot=True,\n",
728+
" with_ionosphere=True,\n",
729+
" with_dish_effects=True,\n",
730+
" with_smearing=with_smearing\n",
731+
" )\n",
732+
"\n",
733+
"\n",
734+
"cpus = jax.devices(\"cpu\")\n",
735+
"gpus = jax.devices(\"cuda\")\n",
736+
"queues = [queue.Queue() for _ in gpus]\n",
737+
"\n",
738+
"# fill queues with input args\n",
739+
"result_idx = 0\n",
705740
"for pointing_offset_stddev in [0, 1, 2, 4] * au.arcmin:\n",
706741
" for axial_focus_error_stddev in [0, 3, 5] * au.mm:\n",
707742
" for horizon_peak_astigmatism_stddev in [0, 1, 2, 4] * au.mm:\n",
708743
" for with_smearing in [True, False]:\n",
709-
" main(\n",
710-
" seed=0,\n",
711-
" save_folder='sky_loss_11Mar2025_varying_systematics_more_stats',\n",
712-
" array_name='dsa2000_optimal_v1',\n",
713-
" pointing=ac.ICRS(0 * au.deg, 0 * au.deg),\n",
714-
" num_measure_points=256,\n",
715-
" angular_radius=1.75 * au.deg,\n",
716-
" prior_psf_sidelobe_peak=1e-3,\n",
717-
" bright_source_id='nvss_calibrators',\n",
718-
" pointing_offset_stddev=pointing_offset_stddev,\n",
719-
" axial_focus_error_stddev=axial_focus_error_stddev,\n",
720-
" horizon_peak_astigmatism_stddev=horizon_peak_astigmatism_stddev,\n",
721-
" turbulent=True,\n",
722-
" dawn=True,\n",
723-
" high_sun_spot=True,\n",
724-
" with_ionosphere=True,\n",
725-
" with_dish_effects=True,\n",
726-
" with_smearing=with_smearing\n",
727-
" )\n",
744+
" queue = queues[result_idx % len(gpus)]\n",
745+
" gpu = gpus[result_idx % len(gpus)]\n",
746+
" cpu = cpus[result_idx % len(cpus)]\n",
747+
" queue.put((run_varying_systematics, result_idx, cpu, gpu, pointing_offset_stddev,\n",
748+
" axial_focus_error_stddev, horizon_peak_astigmatism_stddev, with_smearing))\n",
728749
"\n",
729-
"fill_registries()\n",
730-
"survey_pointings = misc_registry.get_instance(misc_registry.get_match('survey_pointings'))\n",
731-
"pointings = survey_pointings.survey_pointings_v1()\n",
732-
"for pointing in tqdm(pointings):\n",
733-
" print(pointing)\n",
750+
"\n",
751+
"def run_survey(result_idx, cpu, gpu):\n",
734752
" main(\n",
753+
" cpu=cpu,\n",
754+
" gpu=gpu,\n",
755+
" result_num=result_idx,\n",
735756
" seed=0,\n",
736757
" save_folder='sky_loss_11Mar2025_full_survey_more_stats',\n",
737758
" array_name='dsa2000_optimal_v1',\n",
@@ -752,6 +773,39 @@
752773
" )\n",
753774
"\n",
754775
"\n",
776+
"fill_registries()\n",
777+
"survey_pointings = misc_registry.get_instance(misc_registry.get_match('survey_pointings'))\n",
778+
"pointings = survey_pointings.survey_pointings_v1()\n",
779+
"\n",
780+
"# fill queues with input args\n",
781+
"result_idx = 0\n",
782+
"for pointing in tqdm(pointings):\n",
783+
" queue = queues[result_idx % len(gpus)]\n",
784+
" gpu = gpus[result_idx % len(gpus)]\n",
785+
" cpu = cpus[result_idx % len(cpus)]\n",
786+
" queue.put((run_survey, result_idx, cpu, gpu))\n",
787+
" result_idx += 1\n",
788+
"\n",
789+
"\n",
790+
"# now run the jobs in thread pool\n",
791+
"def worker(queue):\n",
792+
" while True:\n",
793+
" args = queue.get()\n",
794+
" if args is None:\n",
795+
" break\n",
796+
" f = args[0]\n",
797+
" args = args[1:]\n",
798+
" f(*args)\n",
799+
"\n",
800+
"\n",
801+
"# now run the jobs in thread pool, each job processes a queue\n",
802+
"import concurrent.futures\n",
803+
"\n",
804+
"with concurrent.futures.ThreadPoolExecutor() as executor:\n",
805+
" for queue in queues:\n",
806+
" executor.submit(worker, queue)\n",
807+
"\n",
808+
"\n",
755809
"\n",
756810
"\n"
757811
],

dsa2000_cal/notebooks/performance_test_calibration.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
"from dsa2000_common.visibility_model.source_models.celestial.base_point_source_model import build_point_source_model, \\\n",
7373
" BasePointSourceModel\n",
7474
"\n",
75-
"from dsa2000_fm.forward_models.streaming.average_utils import average_rule\n",
75+
"from dsa2000_fm.actors.average_utils import average_rule\n",
7676
"\n",
7777
"tfpd = tfp.distributions\n",
7878
"\n",

dsa2000_cal/scripts/dd_calibrate_from_casa_ms/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22

3-
from dsa2000_fm.forward_models.streaming.calibrator import Calibration
3+
from dsa2000_fm.actors.calibrator import Calibration
44

55
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"
66

0 commit comments

Comments
 (0)