Skip to content

Commit 26cabed

Browse files
committed
Merge branch 'joshs-working-branch' into develop
2 parents fa87c76 + dfe2518 commit 26cabed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+3383
-2154
lines changed

dsa2000_cal/benchmarking/calibration/benchmark_gpu.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import os
22
from functools import partial
33

4+
from jaxlib.xla_extension import XlaRuntimeError
5+
46
os.environ['JAX_PLATFORMS'] = 'cuda,cpu'
57
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'
68
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}"
@@ -127,7 +129,8 @@ def prepare_data(D: int, T, C, BTs, BTm, BCs, BCm) -> Dict[str, Any]:
127129
vis_data = np.zeros((T // BTm, B, C // BCm, 2, 2), dtype=mp_policy.vis_dtype)
128130
gains = np.zeros((D, T // BTs, num_antennas, C // BCs, 2, 2), dtype=mp_policy.gain_dtype)
129131

130-
dsa_logger.info(f"Model size D * {vis_model.nbytes / D / 2 ** 20} MB, gain size D * {gains.nbytes / D / 2 ** 10} KB")
132+
dsa_logger.info(
133+
f"Model size D * {vis_model.nbytes / D / 2 ** 20} MB, gain size D * {gains.nbytes / D / 2 ** 10} KB")
131134

132135
return dict(
133136
vis_model=vis_model,
@@ -278,20 +281,26 @@ def main():
278281
for task, build_sharded_entry_point_fn in zip(
279282
['R(x)', 'J^T.R(x)', 'J^T.J.p'],
280283
[build_sharded_entry_point_R, build_sharded_entry_point_JtR, build_sharded_entry_point_JtJp]):
281-
run(build_sharded_entry_point_fn, T=T, C=C, BTs=BTs, BTm=BTm, BCs=BCs, BCm=BCm, backend=backend, m=10,
282-
task=task, scheme=scheme)
283-
run_cal(
284-
T=T,
285-
C=C,
286-
BTs=BTs,
287-
BCs=BCs,
288-
BTm=BTm,
289-
BCm=BCm,
290-
backend=backend,
291-
m=10,
292-
task='LM-Solver 1-iter 1-CG-iter',
293-
scheme=scheme
294-
)
284+
try:
285+
run(build_sharded_entry_point_fn, T=T, C=C, BTs=BTs, BTm=BTm, BCs=BCs, BCm=BCm, backend=backend,
286+
m=10,
287+
task=task, scheme=scheme)
288+
except XlaRuntimeError:
289+
dsa_logger.info(f"{task} {scheme}: {backend}: XlaRuntimeError")
290+
try:
291+
run_cal(
292+
T=T,
293+
C=C,
294+
BTs=BTs,
295+
BCs=BCs,
296+
BTm=BTm,
297+
BCm=BCm, backend=backend,
298+
m=10,
299+
task='LM-Solver 1-iter 1-CG-iter',
300+
scheme=scheme
301+
)
302+
except XlaRuntimeError:
303+
dsa_logger.info(f"LM-Solver 1-iter 1-CG-iter {scheme}: {backend}: XlaRuntimeError")
295304

296305

297306
if __name__ == '__main__':

0 commit comments

Comments
 (0)