2727
2828from airflow .decorators import task
2929from airflow .exceptions import AirflowFailException
30- from google .cloud .monitoring_v3 import types
31- import kubernetes
3230
33- from dags .tpu_observability .utils .node_pool_util import Info as node_pool_info
3431from dags .tpu_observability .utils import subprocess_util as subprocess
3532from dags .tpu_observability .utils .gcp_util import query_time_series
33+ from dags .tpu_observability .utils .node_pool_util import Info as node_pool_info
3634from dags .tpu_observability .utils .time_util import TimeUtil
35+ from google .cloud .monitoring_v3 import types
36+ import kubernetes
3737from xlml .utils import gke
3838
3939
@@ -58,30 +58,38 @@ class Workload:
5858 os.environ.setdefault("JAX_USE_PJIT", "true")
5959 jax.distributed.initialize()
6060
61+ idx = jax.process_index()
6162 global_devices = jax.devices()
62- print(f"[Host {jax.process_index() }] Got {len(global_devices)} global devices")
63+ print(f"[Host {idx }] Got {len(global_devices)} global devices")
6364 mesh = Mesh(global_devices, ("x",))
6465
65- print(f"[Host {jax.process_index() }] Allocating data...")
66- print(f"[Host {jax.process_index() }] Defining sharding...")
66+ print(f"[Host {idx }] Allocating data...")
67+ print(f"[Host {idx }] Defining sharding...")
6768 size = 32768
6869 global_shape = (size, size)
69- sharding = NamedSharding(mesh, jax.sharding.PartitionSpec("x", None))
70+ sharding = NamedSharding(
71+ mesh, jax.sharding.PartitionSpec("x", None)
72+ )
7073
71- print(f"[Host {jax.process_index() }] Creating sharded data directly on devices...")
74+ print(f"[Host {idx }] Creating sharded data directly on devices...")
7275
7376 def ones_callback(index):
74- resolved_indices = [s.indices(global_shape[i]) for i, s in enumerate(index)]
75- local_shape = tuple(stop - start for start, stop, step in resolved_indices)
77+ resolved_indices = [
78+ s.indices(global_shape[i]) for i, s in enumerate(index)
79+ ]
80+ local_shape = tuple(
81+ stop - start for start, stop, step in resolved_indices
82+ )
7683
7784 return jnp.ones(local_shape, dtype=jnp.float32)
7885
79- x = jax.make_array_from_callback(global_shape, sharding, ones_callback)
80- y = jax.make_array_from_callback(global_shape, sharding, ones_callback)
86+ x = jax.make_array_from_callback(
87+ global_shape, sharding, ones_callback)
88+ y = jax.make_array_from_callback(
89+ global_shape, sharding, ones_callback)
8190
82- print(f"[Host {jax.process_index() }] Data on device")
91+ print(f"[Host {idx }] Data on device")
8392
84- # ========= Define heavy workload =========
8593 @pjit
8694 def matmul_ultra_heavy(x, y):
8795 tmp1 = jnp.dot(x, y)
@@ -98,15 +106,15 @@ def matmul_ultra_heavy(x, y):
98106 print(f"[Host {jax.process_index()}] Starting benchmark...")
99107
100108 start = time.time()
101- for i in range(1_000_000): # Remember to control loop time to control experiment time
109+ for i in range(1_000_000):
102110 result = matmul_ultra_heavy(x, y)
103111 result.block_until_ready()
104112 end = time.time()
105113
106114 if jax.process_index() == 0:
107115 print(f"Total time: {end - start:.2f} seconds (on full v6e-16)")
108- '
109- echo "sleep ..."
116+ ' &&
117+ echo "Workload finished, sleeping now ..." &&
110118 sleep 10000
111119 """
112120 ),
@@ -489,17 +497,18 @@ def delete_one_random_pod(
489497 Defaults to "default".
490498
491499 Raises:
492- AirflowFailException: If no running pods are found in the specified namespace.
500+ AirflowFailException: If no running pods are found in the specified
501+ namespace.
493502 """
494503 running_pods = get_running_pods (node_pool = node_pool , namespace = namespace )
495504 if not running_pods :
496- logging .error (f "No running pods found in namespace: { namespace } " )
505+ logging .error ("No running pods found in namespace: %s" , namespace )
497506 raise AirflowFailException (
498507 f"No running pods found in namespace: { namespace } "
499508 )
500509
501510 target_pod = random .choice (running_pods )
502- logging .info (f "Targeting pod for deletion: { target_pod } " )
511+ logging .info ("Targeting pod for deletion: %s" , target_pod )
503512
504513 with tempfile .NamedTemporaryFile () as temp_config_file :
505514 env = os .environ .copy ()
@@ -513,7 +522,7 @@ def delete_one_random_pod(
513522 ])
514523
515524 subprocess .run_exec (cmd , env = env )
516- logging .info (f "Successfully initiated deletion for pod: { target_pod } " )
525+ logging .info ("Successfully initiated deletion for pod: %s" , target_pod )
517526
518527
519528@task .sensor (poke_interval = 30 , timeout = 900 , mode = "poke" )
@@ -587,8 +596,7 @@ def wait_for_jobset_started(
587596def wait_for_jobset_ttr_to_be_found (
588597 node_pool : node_pool_info , jobset_name : str
589598) -> bool :
590- """
591- Polls the jobset time_between_interruptions metric.
599+ """Polls the jobset time_between_interruptions metric.
592600
593601 A sensor task which polls the jobset time_between_interruptions metric
594602 every 60 seconds for 60 minutes. 60 minutes is used here since this
@@ -598,8 +606,9 @@ def wait_for_jobset_ttr_to_be_found(
598606 impractical for the test to run longer.
599607
600608 Args:
601- info(Info): An instance of the Info class that encapsulates
602- the configuration and metadata of a GKE node pool and workload.
609+ node_pool: An instance of the Info class that encapsulates
610+ the configuration and metadata of a GKE node pool and workload.
611+ jobset_name: The name of the JobSet.
603612 """
604613 now = datetime .datetime .now ()
605614
0 commit comments