|
42 | 42 | "source": [ |
43 | 43 | "\n", |
44 | 44 | "\n", |
45 | | - "import glob\n", |
46 | 45 | "import os\n", |
| 46 | + "import queue\n", |
47 | 47 | "import time\n", |
48 | 48 | "from typing import NamedTuple, Dict\n", |
49 | 49 | "\n", |
|
318 | 318 | "\n", |
319 | 319 | "\n", |
320 | 320 | "def main(\n", |
| 321 | + " cpu,\n", |
| 322 | + " gpu,\n", |
| 323 | + " result_num: int,\n", |
321 | 324 | " seed: int,\n", |
322 | 325 | " save_folder: str,\n", |
323 | 326 | " array_name: str,\n", |
|
338 | 341 | "):\n", |
339 | 342 | " plt.close('all')\n", |
340 | 343 | " t0 = time.time()\n", |
341 | | - " result_num = len(glob.glob(os.path.join(save_folder, 'result_*.json')))\n", |
342 | | - " cpu = jax.devices(\"cpu\")[0]\n", |
343 | | - " gpu = jax.devices(\"cuda\")[0] # or \"gpu\" depending on platform\n", |
344 | 344 | "\n", |
345 | 345 | " key = jax.random.PRNGKey(seed)\n", |
346 | 346 | " fill_registries()\n", |
|
702 | 702 | " result_values: Dict[str, float]\n", |
703 | 703 | "\n", |
704 | 704 | "\n", |
| 705 | + "gpu = jax.devices(\"cuda\")[0] # or \"gpu\" depending on platform\n", |
| 706 | + "\n", |
| 707 | + "\n", |
| 708 | + "def run_varying_systematics(result_idx, cpu, gpu, pointing_offset_stddev, axial_focus_error_stddev,\n", |
| 709 | + " horizon_peak_astigmatism_stddev, with_smearing):\n", |
| 710 | + " main(\n", |
| 711 | + " cpu=cpu,\n", |
| 712 | + " gpu=gpu,\n", |
| 713 | + " result_num=result_idx,\n", |
| 714 | + " seed=0,\n", |
| 715 | + " save_folder='sky_loss_11Mar2025_varying_systematics_more_stats',\n", |
| 716 | + " array_name='dsa2000_optimal_v1',\n", |
| 717 | + " pointing=ac.ICRS(0 * au.deg, 0 * au.deg),\n", |
| 718 | + " num_measure_points=256,\n", |
| 719 | + " angular_radius=1.75 * au.deg,\n", |
| 720 | + " prior_psf_sidelobe_peak=1e-3,\n", |
| 721 | + " bright_source_id='nvss_calibrators',\n", |
| 722 | + " pointing_offset_stddev=pointing_offset_stddev,\n", |
| 723 | + " axial_focus_error_stddev=axial_focus_error_stddev,\n", |
| 724 | + " horizon_peak_astigmatism_stddev=horizon_peak_astigmatism_stddev,\n", |
| 725 | + " turbulent=True,\n", |
| 726 | + " dawn=True,\n", |
| 727 | + " high_sun_spot=True,\n", |
| 728 | + " with_ionosphere=True,\n", |
| 729 | + " with_dish_effects=True,\n", |
| 730 | + " with_smearing=with_smearing\n", |
| 731 | + " )\n", |
| 732 | + "\n", |
| 733 | + "\n", |
| 734 | + "cpus = jax.devices(\"cpu\")\n", |
| 735 | + "gpus = jax.devices(\"cuda\")\n", |
| 736 | + "queues = [queue.Queue() for _ in gpus]\n", |
| 737 | + "\n", |
| 738 | + "# fill queues with input args\n", |
| 739 | + "result_idx = 0\n", |
705 | 740 | "for pointing_offset_stddev in [0, 1, 2, 4] * au.arcmin:\n", |
706 | 741 | " for axial_focus_error_stddev in [0, 3, 5] * au.mm:\n", |
707 | 742 | " for horizon_peak_astigmatism_stddev in [0, 1, 2, 4] * au.mm:\n", |
708 | 743 | " for with_smearing in [True, False]:\n", |
709 | | - " main(\n", |
710 | | - " seed=0,\n", |
711 | | - " save_folder='sky_loss_11Mar2025_varying_systematics_more_stats',\n", |
712 | | - " array_name='dsa2000_optimal_v1',\n", |
713 | | - " pointing=ac.ICRS(0 * au.deg, 0 * au.deg),\n", |
714 | | - " num_measure_points=256,\n", |
715 | | - " angular_radius=1.75 * au.deg,\n", |
716 | | - " prior_psf_sidelobe_peak=1e-3,\n", |
717 | | - " bright_source_id='nvss_calibrators',\n", |
718 | | - " pointing_offset_stddev=pointing_offset_stddev,\n", |
719 | | - " axial_focus_error_stddev=axial_focus_error_stddev,\n", |
720 | | - " horizon_peak_astigmatism_stddev=horizon_peak_astigmatism_stddev,\n", |
721 | | - " turbulent=True,\n", |
722 | | - " dawn=True,\n", |
723 | | - " high_sun_spot=True,\n", |
724 | | - " with_ionosphere=True,\n", |
725 | | - " with_dish_effects=True,\n", |
726 | | - " with_smearing=with_smearing\n", |
727 | | - " )\n", |
| 744 | + " queue = queues[result_idx % len(gpus)]\n", |
| 745 | + " gpu = gpus[result_idx % len(gpus)]\n", |
| 746 | + " cpu = cpus[result_idx % len(cpus)]\n", |
| 747 | + " queue.put((run_varying_systematics, result_idx, cpu, gpu, pointing_offset_stddev,\n", |
| 748 | + " axial_focus_error_stddev, horizon_peak_astigmatism_stddev, with_smearing))\n", |
728 | 749 | "\n", |
729 | | - "fill_registries()\n", |
730 | | - "survey_pointings = misc_registry.get_instance(misc_registry.get_match('survey_pointings'))\n", |
731 | | - "pointings = survey_pointings.survey_pointings_v1()\n", |
732 | | - "for pointing in tqdm(pointings):\n", |
733 | | - " print(pointing)\n", |
| 750 | + "\n", |
| 751 | + "def run_survey(result_idx, cpu, gpu):\n", |
734 | 752 | " main(\n", |
| 753 | + " cpu=cpu,\n", |
| 754 | + " gpu=gpu,\n", |
| 755 | + " result_num=result_idx,\n", |
735 | 756 | " seed=0,\n", |
736 | 757 | " save_folder='sky_loss_11Mar2025_full_survey_more_stats',\n", |
737 | 758 | " array_name='dsa2000_optimal_v1',\n", |
|
752 | 773 | " )\n", |
753 | 774 | "\n", |
754 | 775 | "\n", |
| 776 | + "fill_registries()\n", |
| 777 | + "survey_pointings = misc_registry.get_instance(misc_registry.get_match('survey_pointings'))\n", |
| 778 | + "pointings = survey_pointings.survey_pointings_v1()\n", |
| 779 | + "\n", |
| 780 | + "# fill queues with input args\n", |
| 781 | + "result_idx = 0\n", |
| 782 | + "for pointing in tqdm(pointings):\n", |
| 783 | + " queue = queues[result_idx % len(gpus)]\n", |
| 784 | + " gpu = gpus[result_idx % len(gpus)]\n", |
| 785 | + " cpu = cpus[result_idx % len(cpus)]\n", |
| 786 | + " queue.put((run_survey, result_idx, cpu, gpu))\n", |
| 787 | + " result_idx += 1\n", |
| 788 | + "\n", |
| 789 | + "\n", |
| 790 | + "# now run the jobs in thread pool\n", |
| 791 | + "def worker(queue):\n", |
| 792 | + " while True:\n", |
| 793 | + " args = queue.get()\n", |
| 794 | + " if args is None:\n", |
| 795 | + " break\n", |
| 796 | + " f = args[0]\n", |
| 797 | + " args = args[1:]\n", |
| 798 | + " f(*args)\n", |
| 799 | + "\n", |
| 800 | + "\n", |
| 801 | + "# now run the jobs in thread pool, each job processes a queue\n", |
| 802 | + "import concurrent.futures\n", |
| 803 | + "\n", |
| 804 | + "with concurrent.futures.ThreadPoolExecutor() as executor:\n", |
| 805 | + " for queue in queues:\n", |
| 806 | + " executor.submit(worker, queue)\n", |
| 807 | + "\n", |
| 808 | + "\n", |
755 | 809 | "\n", |
756 | 810 | "\n" |
757 | 811 | ], |
|
0 commit comments