Skip to content

Commit fa87c76

Browse files
committed
Merge branch 'joshs-working-branch' into develop
2 parents 2767797 + 17628c1 commit fa87c76

23 files changed: +2043 additions, −682 deletions

dsa2000_cal/benchmarking/calibration/benchmark_JtJg_calculation_TBC_cpu.py

Lines changed: 24 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,12 @@
11
import os
22
from functools import partial
33

4-
from dsa2000_common.common.fit_benchmark import fit_timings
5-
64
os.environ['JAX_PLATFORMS'] = 'cpu'
7-
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
85
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}"
96

107
from jax._src.partition_spec import PartitionSpec
118
from jax.experimental.shard_map import shard_map
9+
from dsa2000_common.common.fit_benchmark import fit_timings
1210

1311
from dsa2000_common.common.jax_utils import create_mesh
1412
from dsa2000_common.common.jvp_linear_op import JVPLinearOp
@@ -24,7 +22,10 @@
2422
from dsa2000_cal.ops.residuals import compute_residual_TBC
2523

2624

27-
def prepare_data(D: int, Ts, Tm, Cs, Cm) -> Dict[str, Any]:
25+
def prepare_data(D: int, T, C, Ts, Tm, Cs, Cm) -> Dict[str, Any]:
26+
assert T % Tm == 0 and T % Ts == 0
27+
assert C % Cm == 0 and C % Cs == 0
28+
2829
num_antennas = 2048
2930
baseline_pairs = np.asarray(list(itertools.combinations(range(num_antennas), 2)),
3031
dtype=np.int32)
@@ -36,9 +37,10 @@ def prepare_data(D: int, Ts, Tm, Cs, Cm) -> Dict[str, Any]:
3637
antenna2 = antenna2[sort_idxs]
3738

3839
B = antenna1.shape[0]
39-
vis_model = np.zeros((D, Tm, B, Cm, 2, 2), dtype=mp_policy.vis_dtype)
40-
vis_data = np.zeros((Ts, B, Cs, 2, 2), dtype=mp_policy.vis_dtype)
41-
gains = np.zeros((D, Tm, num_antennas, Cm, 2, 2), dtype=mp_policy.gain_dtype)
40+
vis_model = np.zeros((D, T // Tm, B, C // Cm, 2, 2), dtype=mp_policy.vis_dtype)
41+
vis_data = np.zeros((T // Tm, B, C // Cm, 2, 2), dtype=mp_policy.vis_dtype)
42+
gains = np.zeros((D, T // Ts, num_antennas, C // Cs, 2, 2), dtype=mp_policy.gain_dtype)
43+
4244
return dict(
4345
vis_model=vis_model,
4446
vis_data=vis_data,
@@ -86,84 +88,35 @@ def entry_point_sharded(local_data):
8688
return entry_point_sharded, mesh
8789

8890

89-
def main():
90-
cpus = jax.devices("cpu")
91-
# gpus = jax.devices("cuda")
92-
cpu = cpus[0]
93-
# gpu = gpus[0]
94-
95-
entry_point_jit = jax.jit(entry_point)
96-
sharded_entry_point, mesh = build_sharded_entry_point(cpus)
91+
def run(T, C, Ts, Tm, Cs, Cm, backend, m=10, task='J^T.J.g(x)'):
92+
devices = jax.devices(backend)
93+
sharded_entry_point, mesh = build_sharded_entry_point(devices)
9794
sharded_entry_point_jit = jax.jit(sharded_entry_point)
9895
# Run benchmarking over number of calibration directions
99-
time_array = []
10096
shard_time_array = []
10197
d_array = []
10298
for D in range(1, 9):
103-
data = prepare_data(D, Ts=1, Tm=1, Cs=1, Cm=1)
104-
with jax.default_device(cpu):
105-
data = jax.device_put(data)
106-
entry_point_jit_compiled = entry_point_jit.lower(data).compile()
107-
t0 = time.time()
108-
for _ in range(3):
109-
jax.block_until_ready(entry_point_jit_compiled(data))
110-
t1 = time.time()
111-
dt = (t1 - t0) / 3
112-
dsa_logger.info(f"TBC: Residual: CPU D={D}: {dt}")
113-
time_array.append(dt)
114-
d_array.append(D)
115-
99+
data = prepare_data(D, T, C, Ts, Tm, Cs, Cm)
116100
sharded_entry_point_jit_compiled = sharded_entry_point_jit.lower(data).compile()
117101
t0 = time.time()
118-
for _ in range(3):
102+
for _ in range(m):
119103
jax.block_until_ready(sharded_entry_point_jit_compiled(data))
120104
t1 = time.time()
121-
dt = (t1 - t0) / 3
122-
dsa_logger.info(f"TBC: Residual (sharded): CPU D={D}: {dt}")
105+
dt = (t1 - t0) / m
106+
dsa_logger.info(f"{task}: {backend} D={D}: {dt}")
123107
shard_time_array.append(dt)
124-
#
125-
# data = prepare_data(D, Ts=4, Tm=1, Cs=4, Cm=1)
126-
# with jax.default_device(cpu):
127-
# data = jax.device_put(data)
128-
# entry_point_jit_compiled = entry_point_jit.lower(data).compile()
129-
# t0 = time.time()
130-
# jax.block_until_ready(entry_point_jit_compiled(data))
131-
# t1 = time.time()
132-
# dsa_logger.info(f"TBC: Subtract (per-GPU): CPU D={D}: {t1 - t0}")
133-
#
134-
# sharded_entry_point_jit_compiled = sharded_entry_point_jit.lower(data).compile()
135-
# t0 = time.time()
136-
# for _ in range(1):
137-
# jax.block_until_ready(sharded_entry_point_jit_compiled(data))
138-
# t1 = time.time()
139-
# dt = (t1 - t0) / 1
140-
# dsa_logger.info(f"TBC: Subtract (per-GPU sharded): CPU D={D}: {dt}")
141-
#
142-
# data = prepare_data(D, Ts=4, Tm=1, Cs=40, Cm=1)
143-
# with jax.default_device(cpu):
144-
# data = jax.device_put(data)
145-
# entry_point_jit_compiled = entry_point_jit.lower(data).compile()
146-
# t0 = time.time()
147-
# jax.block_until_ready(entry_point_jit_compiled(data))
148-
# t1 = time.time()
149-
# dsa_logger.info(f"TBC: Subtract (all-GPU): CPU D={D}: {t1 - t0}")
150-
#
151-
# sharded_entry_point_jit_compiled = sharded_entry_point_jit.lower(data).compile()
152-
# t0 = time.time()
153-
# for _ in range(1):
154-
# jax.block_until_ready(sharded_entry_point_jit_compiled(data))
155-
# t1 = time.time()
156-
# dt = (t1 - t0) / 1
157-
# dsa_logger.info(f"TBC: Subtract (all-GPU sharded): CPU D={D}: {dt}")
158-
159-
time_array = np.array(time_array)
108+
d_array.append(D)
109+
160110
shard_time_array = np.array(shard_time_array)
161111
d_array = np.array(d_array)
162112

163-
a, b, c = fit_timings(d_array, time_array)
164-
dsa_logger.info(f"Fit: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
165113
a, b, c = fit_timings(d_array, shard_time_array)
166-
dsa_logger.info(f"Fit (sharded): t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
114+
dsa_logger.info(f"{task}: {backend}: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
115+
116+
117+
def main():
118+
# run(T=4, C=40, Ts=4, Tm=4, Cs=40, Cm=40, backend='cpu', m=10, task='J^T.J.g(x) [all-GPU]')
119+
run(T=4, C=4, Ts=4, Tm=4, Cs=4, Cm=4, backend='cpu', m=10, task='J^T.J.g(x) [Full.Avg. per-GPU]')
167120

168121

169122
if __name__ == '__main__':

dsa2000_cal/benchmarking/calibration/benchmark_JtJg_calculation_TBC_gpu.py

Lines changed: 33 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import os
22
from functools import partial
33

4-
from dsa2000_common.common.fit_benchmark import fit_timings
4+
from dsa2000_cal.solvers.cg import tree_scalar_mul, tree_add
55

6-
os.environ['JAX_PLATFORMS'] = 'cuda'
6+
os.environ['JAX_PLATFORMS'] = 'cuda,cpu'
77
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
8+
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}"
89

910
from jax._src.partition_spec import PartitionSpec
1011
from jax.experimental.shard_map import shard_map
12+
from dsa2000_common.common.fit_benchmark import fit_timings
1113

1214
from dsa2000_common.common.jax_utils import create_mesh
1315
from dsa2000_common.common.jvp_linear_op import JVPLinearOp
@@ -16,14 +18,17 @@
1618
import itertools
1719
import time
1820
from typing import Dict, Any
19-
21+
import jax.numpy as jnp
2022
import jax
2123
import numpy as np
2224

2325
from dsa2000_cal.ops.residuals import compute_residual_TBC
2426

2527

26-
def prepare_data(D: int, Ts, Tm, Cs, Cm) -> Dict[str, Any]:
28+
def prepare_data(D: int, T, C, Ts, Tm, Cs, Cm) -> Dict[str, Any]:
29+
assert T % Tm == 0 and T % Ts == 0
30+
assert C % Cm == 0 and C % Cs == 0
31+
2732
num_antennas = 2048
2833
baseline_pairs = np.asarray(list(itertools.combinations(range(num_antennas), 2)),
2934
dtype=np.int32)
@@ -35,9 +40,10 @@ def prepare_data(D: int, Ts, Tm, Cs, Cm) -> Dict[str, Any]:
3540
antenna2 = antenna2[sort_idxs]
3641

3742
B = antenna1.shape[0]
38-
vis_model = np.zeros((D, Tm, B, Cm, 2, 2), dtype=mp_policy.vis_dtype)
39-
vis_data = np.zeros((Ts, B, Cs, 2, 2), dtype=mp_policy.vis_dtype)
40-
gains = np.zeros((D, Tm, num_antennas, Cm, 2, 2), dtype=mp_policy.gain_dtype)
43+
vis_model = np.zeros((D, T // Tm, B, C // Cm, 2, 2), dtype=mp_policy.vis_dtype)
44+
vis_data = np.zeros((T // Tm, B, C // Cm, 2, 2), dtype=mp_policy.vis_dtype)
45+
gains = np.zeros((D, T // Ts, num_antennas, C // Cs, 2, 2), dtype=mp_policy.gain_dtype)
46+
4147
return dict(
4248
vis_model=vis_model,
4349
vis_data=vis_data,
@@ -63,7 +69,10 @@ def fn(gains):
6369
J = J_bare(gains)
6470
R = fn(gains)
6571
g = J.matvec(R, adjoint=True)
66-
return J.matvec(J.matvec(g), adjoint=True)
72+
p = jax.tree.map(jnp.ones_like, g)
73+
JTJv = J.matvec(J.matvec(p), adjoint=True)
74+
damping = jnp.asarray(1.)
75+
return tree_add(JTJv, tree_scalar_mul(damping, p))
6776

6877

6978
def build_sharded_entry_point(devices):
@@ -85,51 +94,37 @@ def entry_point_sharded(local_data):
8594
return entry_point_sharded, mesh
8695

8796

88-
def main():
89-
gpus = jax.devices("cuda")
90-
91-
sharded_entry_point, mesh = build_sharded_entry_point(gpus)
97+
def run(T, C, Ts, Tm, Cs, Cm, backend, m=10, task='J^T.J.g(x)'):
98+
devices = jax.devices(backend)
99+
sharded_entry_point, mesh = build_sharded_entry_point(devices)
92100
sharded_entry_point_jit = jax.jit(sharded_entry_point)
93101
# Run benchmarking over number of calibration directions
94102
shard_time_array = []
95103
d_array = []
96104
for D in range(1, 9):
97-
data = prepare_data(D, Ts=1, Tm=1, Cs=1, Cm=1)
98-
105+
data = prepare_data(D, T, C, Ts, Tm, Cs, Cm)
99106
sharded_entry_point_jit_compiled = sharded_entry_point_jit.lower(data).compile()
100107
t0 = time.time()
101-
for _ in range(10):
108+
for _ in range(m):
102109
jax.block_until_ready(sharded_entry_point_jit_compiled(data))
103110
t1 = time.time()
104-
dt = (t1 - t0) / 10
105-
dsa_logger.info(f"TBC: J^T.J.g (Full avg.): CPU D={D}: {dt}")
106-
d_array.append(D)
111+
dt = (t1 - t0) / m
112+
dsa_logger.info(f"{task}: {backend} D={D}: {dt}")
107113
shard_time_array.append(dt)
108-
109-
data = prepare_data(D, Ts=4, Tm=1, Cs=4, Cm=1)
110-
111-
sharded_entry_point_jit_compiled = sharded_entry_point_jit.lower(data).compile()
112-
t0 = time.time()
113-
for _ in range(10):
114-
jax.block_until_ready(sharded_entry_point_jit_compiled(data))
115-
t1 = time.time()
116-
dt = (t1 - t0) / 10
117-
dsa_logger.info(f"TBC: J^T.J.g (per-GPU w/ reps): CPU D={D}: {dt}")
118-
119-
# data = prepare_data(D, Ts=4, Tm=1, Cs=40, Cm=1)
120-
# sharded_entry_point_jit_compiled = sharded_entry_point_jit.lower(data).compile()
121-
# t0 = time.time()
122-
# for _ in range(1):
123-
# jax.block_until_ready(sharded_entry_point_jit_compiled(data))
124-
# t1 = time.time()
125-
# dt = (t1 - t0) / 1
126-
# dsa_logger.info(f"TBC: Subtract (all-GPU sharded): CPU D={D}: {dt}")
114+
d_array.append(D)
127115

128116
shard_time_array = np.array(shard_time_array)
129117
d_array = np.array(d_array)
130118

131119
a, b, c = fit_timings(d_array, shard_time_array)
132-
dsa_logger.info(f"Fit: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
120+
dsa_logger.info(f"{task}: {backend}: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
121+
122+
123+
def main():
124+
# run(T=4, C=40, Ts=4, Tm=4, Cs=40, Cm=40, backend='cuda', m=10, task='J^T.J.g(x) [all-GPU]')
125+
# run(T=4, C=40, Ts=4, Tm=4, Cs=40, Cm=40, backend='cpu', m=10, task='J^T.J.g(x) [all-GPU]')
126+
run(T=4, C=4, Ts=4, Tm=4, Cs=4, Cm=4, backend='cuda', m=10, task='J^T.J.g(x) [Full.Avg. per-GPU]')
127+
run(T=4, C=4, Ts=4, Tm=4, Cs=4, Cm=4, backend='cpu', m=10, task='J^T.J.g(x) [Full.Avg. per-GPU]')
133128

134129

135130
if __name__ == '__main__':

0 commit comments

Comments (0)