11import os
22from functools import partial
33
4- from dsa2000_common . common . fit_benchmark import fit_timings
4+ from dsa2000_cal . solvers . cg import tree_scalar_mul , tree_add
55
6- os .environ ['JAX_PLATFORMS' ] = 'cuda'
6+ os .environ ['JAX_PLATFORMS' ] = 'cuda,cpu '
77os .environ ['XLA_PYTHON_CLIENT_MEM_FRACTION' ] = '1.0'
8+ os .environ ["XLA_FLAGS" ] = f"--xla_force_host_platform_device_count={ 8 } "
89
910from jax ._src .partition_spec import PartitionSpec
1011from jax .experimental .shard_map import shard_map
12+ from dsa2000_common .common .fit_benchmark import fit_timings
1113
1214from dsa2000_common .common .jax_utils import create_mesh
1315from dsa2000_common .common .jvp_linear_op import JVPLinearOp
1618import itertools
1719import time
1820from typing import Dict , Any
19-
21+ import jax . numpy as jnp
2022import jax
2123import numpy as np
2224
2325from dsa2000_cal .ops .residuals import compute_residual_TBC
2426
2527
26- def prepare_data (D : int , Ts , Tm , Cs , Cm ) -> Dict [str , Any ]:
28+ def prepare_data (D : int , T , C , Ts , Tm , Cs , Cm ) -> Dict [str , Any ]:
29+ assert T % Tm == 0 and T % Ts == 0
30+ assert C % Cm == 0 and C % Cs == 0
31+
2732 num_antennas = 2048
2833 baseline_pairs = np .asarray (list (itertools .combinations (range (num_antennas ), 2 )),
2934 dtype = np .int32 )
@@ -35,9 +40,10 @@ def prepare_data(D: int, Ts, Tm, Cs, Cm) -> Dict[str, Any]:
3540 antenna2 = antenna2 [sort_idxs ]
3641
3742 B = antenna1 .shape [0 ]
38- vis_model = np .zeros ((D , Tm , B , Cm , 2 , 2 ), dtype = mp_policy .vis_dtype )
39- vis_data = np .zeros ((Ts , B , Cs , 2 , 2 ), dtype = mp_policy .vis_dtype )
40- gains = np .zeros ((D , Tm , num_antennas , Cm , 2 , 2 ), dtype = mp_policy .gain_dtype )
43+ vis_model = np .zeros ((D , T // Tm , B , C // Cm , 2 , 2 ), dtype = mp_policy .vis_dtype )
44+ vis_data = np .zeros ((T // Tm , B , C // Cm , 2 , 2 ), dtype = mp_policy .vis_dtype )
45+ gains = np .zeros ((D , T // Ts , num_antennas , C // Cs , 2 , 2 ), dtype = mp_policy .gain_dtype )
46+
4147 return dict (
4248 vis_model = vis_model ,
4349 vis_data = vis_data ,
@@ -63,7 +69,10 @@ def fn(gains):
6369 J = J_bare (gains )
6470 R = fn (gains )
6571 g = J .matvec (R , adjoint = True )
66- return J .matvec (J .matvec (g ), adjoint = True )
72+ p = jax .tree .map (jnp .ones_like , g )
73+ JTJv = J .matvec (J .matvec (p ), adjoint = True )
74+ damping = jnp .asarray (1. )
75+ return tree_add (JTJv , tree_scalar_mul (damping , p ))
6776
6877
6978def build_sharded_entry_point (devices ):
@@ -85,51 +94,37 @@ def entry_point_sharded(local_data):
8594 return entry_point_sharded , mesh
8695
8796
88- def main ():
89- gpus = jax .devices ("cuda" )
90-
91- sharded_entry_point , mesh = build_sharded_entry_point (gpus )
def run(T, C, Ts, Tm, Cs, Cm, backend, m=10, task='J^T.J.g(x)'):
    """Benchmark the sharded entry point over 1..8 calibration directions.

    For each direction count ``D`` the input data is rebuilt with
    ``prepare_data``, the jitted entry point is lowered and compiled ahead of
    time (so compilation is not part of the measurement), and the compiled
    executable is run ``m`` times. The mean wall time per call is logged, and
    the coefficients returned by ``fit_timings`` are logged at the end.

    Args:
        T, C, Ts, Tm, Cs, Cm: forwarded unchanged to ``prepare_data``.
        backend: backend name passed to ``jax.devices``; also used as a label
            in the log lines.
        m: number of timed executions per value of ``D``; must be >= 1.
        task: label prefixed to every log line.

    Raises:
        ValueError: if ``m`` is smaller than 1.
    """
    # Fail before any device lookup or compilation: m == 0 would otherwise
    # surface as a ZeroDivisionError only after the first compile finished.
    if m < 1:
        raise ValueError(f"m must be a positive number of repetitions, got {m}")

    devices = jax.devices(backend)
    sharded_entry_point, _ = build_sharded_entry_point(devices)
    sharded_entry_point_jit = jax.jit(sharded_entry_point)

    # Run benchmarking over number of calibration directions
    d_array = []
    shard_time_array = []
    for D in range(1, 9):
        data = prepare_data(D, T, C, Ts, Tm, Cs, Cm)
        compiled = sharded_entry_point_jit.lower(data).compile()

        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        for _ in range(m):
            jax.block_until_ready(compiled(data))
        dt = (time.perf_counter() - t0) / m

        dsa_logger.info(f"{task}: {backend} D={D}: {dt}")
        d_array.append(D)
        shard_time_array.append(dt)

    a, b, c = fit_timings(np.array(d_array), np.array(shard_time_array))
    dsa_logger.info(f"{task}: {backend}: t(n) = {a:.4f} * n ** {b:.2f} + {c:.4f}")
121+
122+
def main():
    """Run the fully-averaged per-GPU benchmark on CUDA first, then on CPU."""
    # Disabled all-GPU variant, kept for reference (same call for 'cuda' and 'cpu'):
    #   run(T=4, C=40, Ts=4, Tm=4, Cs=40, Cm=40, backend=..., m=10, task='J^T.J.g(x) [all-GPU]')
    for backend in ('cuda', 'cpu'):
        run(T=4, C=4, Ts=4, Tm=4, Cs=4, Cm=4, backend=backend, m=10,
            task='J^T.J.g(x) [Full.Avg. per-GPU]')
133128
134129
135130if __name__ == '__main__' :
0 commit comments