Skip to content

Commit 52c1121

Browse files
committed
Merge branch 'develop'
2 parents 67e21b0 + 18c0ff6 commit 52c1121

File tree

109 files changed

+7398
-1710
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

109 files changed

+7398
-1710
lines changed

dsa2000_cal/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ cd DSA2000-Cal
6161

6262
```bash
6363
conda activate dsa2000_cal_py
64-
pip install dsa2000_call
64+
pip install dsa2000_cal
6565
```
6666

67-
8. Set up PyCharm for development
67+
8. Set up PyCharm for development (optional but recommended).
6868

6969
1. Make sure you have created a `dsa2000_cal_py` conda env as above, and installed requirements files.
7070
2. Create a new project in PyCharm in the repo root directory `/home/username/git/DSA2000-Cal`. Use an empty project

dsa2000_cal/benchmarking/DFT_predict/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import astropy.units as au
99
import jax
1010
import numpy as np
11-
from tomographic_kernel.frames import ENU
11+
from dsa2000_common.common.enu_frame import ENU
1212

1313
from dsa2000_common.common.quantity_utils import quantity_to_jnp, time_to_jnp
1414
from dsa2000_common.delay_models.base_far_field_delay_engine import build_far_field_delay_engine

dsa2000_cal/benchmarking/calibration/main.py renamed to dsa2000_cal/benchmarking/actual_calibration/main.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22

3+
from dsa2000_common.common.ray_utils import TimerLog
34
from dsa2000_common.visibility_model.source_models.celestial.base_point_source_model import BasePointSourceModel
45

56
os.environ['JAX_PLATFORMS'] = 'cuda'
@@ -12,9 +13,7 @@
1213

1314
# os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"
1415

15-
import dataclasses
1616
import os
17-
import time
1817
from functools import partial
1918
from typing import Generator, Tuple, List
2019

@@ -31,7 +30,7 @@
3130
import tensorflow_probability.substrates.jax as tfp
3231
from astropy import units as au, coordinates as ac
3332
from matplotlib import pyplot as plt
34-
from tomographic_kernel.frames import ENU
33+
from dsa2000_common.common.enu_frame import ENU
3534

3635
from dsa2000_assets.content_registry import fill_registries
3736
from dsa2000_assets.registries import array_registry
@@ -65,21 +64,6 @@
6564
tfpd = tfp.distributions
6665

6766

68-
@dataclasses.dataclass
69-
class TimerLog:
70-
msg: str
71-
72-
def __post_init__(self):
73-
self.t0 = time.time()
74-
75-
def __enter__(self):
76-
print(f"{self.msg}")
77-
78-
def __exit__(self, exc_type, exc_val, exc_tb):
79-
print(f"... took {time.time() - self.t0:.3f} seconds")
80-
return False
81-
82-
8367
def build_mock_obs_setup(array_name: str, num_sol_ints_time: int, frac_aperture: float = 1.):
8468
fill_registries()
8569
array = array_registry.get_instance(array_registry.get_match(array_name))
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#!/bin/bash
2+
3+
conda create -n cal_benchmark python=3.11
4+
conda activate cal_benchmark
5+
6+
pip install jax[cuda12] jaxlib 'numpy<2' nvtx
7+
8+
python standalone_lm_multi_step.py
9+
10+
conda deactivate
11+
conda remove -n cal_benchmark --all

0 commit comments

Comments
 (0)