|
1 | 1 | import os |
2 | 2 | from functools import partial |
3 | 3 |
|
| 4 | +from jaxlib.xla_extension import XlaRuntimeError |
| 5 | + |
4 | 6 | os.environ['JAX_PLATFORMS'] = 'cuda,cpu' |
5 | 7 | os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0' |
6 | 8 | os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={8}" |
@@ -127,7 +129,8 @@ def prepare_data(D: int, T, C, BTs, BTm, BCs, BCm) -> Dict[str, Any]: |
127 | 129 | vis_data = np.zeros((T // BTm, B, C // BCm, 2, 2), dtype=mp_policy.vis_dtype) |
128 | 130 | gains = np.zeros((D, T // BTs, num_antennas, C // BCs, 2, 2), dtype=mp_policy.gain_dtype) |
129 | 131 |
|
130 | | - dsa_logger.info(f"Model size D * {vis_model.nbytes / D / 2 ** 20} MB, gain size D * {gains.nbytes / D / 2 ** 10} KB") |
| 132 | + dsa_logger.info( |
| 133 | + f"Model size D * {vis_model.nbytes / D / 2 ** 20} MB, gain size D * {gains.nbytes / D / 2 ** 10} KB") |
131 | 134 |
|
132 | 135 | return dict( |
133 | 136 | vis_model=vis_model, |
@@ -278,20 +281,26 @@ def main(): |
278 | 281 | for task, build_sharded_entry_point_fn in zip( |
279 | 282 | ['R(x)', 'J^T.R(x)', 'J^T.J.p'], |
280 | 283 | [build_sharded_entry_point_R, build_sharded_entry_point_JtR, build_sharded_entry_point_JtJp]): |
281 | | - run(build_sharded_entry_point_fn, T=T, C=C, BTs=BTs, BTm=BTm, BCs=BCs, BCm=BCm, backend=backend, m=10, |
282 | | - task=task, scheme=scheme) |
283 | | - run_cal( |
284 | | - T=T, |
285 | | - C=C, |
286 | | - BTs=BTs, |
287 | | - BCs=BCs, |
288 | | - BTm=BTm, |
289 | | - BCm=BCm, |
290 | | - backend=backend, |
291 | | - m=10, |
292 | | - task='LM-Solver 1-iter 1-CG-iter', |
293 | | - scheme=scheme |
294 | | - ) |
| 284 | + try: |
| 285 | + run(build_sharded_entry_point_fn, T=T, C=C, BTs=BTs, BTm=BTm, BCs=BCs, BCm=BCm, backend=backend, |
| 286 | + m=10, |
| 287 | + task=task, scheme=scheme) |
| 288 | + except XlaRuntimeError: |
| 289 | + dsa_logger.info(f"{task} {scheme}: {backend}: XlaRuntimeError") |
| 290 | + try: |
| 291 | + run_cal( |
| 292 | + T=T, |
| 293 | + C=C, |
| 294 | + BTs=BTs, |
| 295 | + BCs=BCs, |
| 296 | + BTm=BTm, |
| 297 | + BCm=BCm, backend=backend, |
| 298 | + m=10, |
| 299 | + task='LM-Solver 1-iter 1-CG-iter', |
| 300 | + scheme=scheme |
| 301 | + ) |
| 302 | + except XlaRuntimeError: |
| 303 | + dsa_logger.info(f"LM-Solver 1-iter 1-CG-iter {scheme}: {backend}: XlaRuntimeError") |
295 | 304 |
|
296 | 305 |
|
297 | 306 | if __name__ == '__main__': |
|
0 commit comments