Skip to content

Commit 29bdcf2

Browse files
committed
format fix
1 parent b9a29e7 commit 29bdcf2

File tree

1 file changed

+34
-25
lines changed

1 file changed

+34
-25
lines changed

dags/tpu_observability/utils/jobset_util.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727

2828
from airflow.decorators import task
2929
from airflow.exceptions import AirflowFailException
30-
from google.cloud.monitoring_v3 import types
31-
import kubernetes
3230

33-
from dags.tpu_observability.utils.node_pool_util import Info as node_pool_info
3431
from dags.tpu_observability.utils import subprocess_util as subprocess
3532
from dags.tpu_observability.utils.gcp_util import query_time_series
33+
from dags.tpu_observability.utils.node_pool_util import Info as node_pool_info
3634
from dags.tpu_observability.utils.time_util import TimeUtil
35+
from google.cloud.monitoring_v3 import types
36+
import kubernetes
3737
from xlml.utils import gke
3838

3939

@@ -58,30 +58,38 @@ class Workload:
5858
os.environ.setdefault("JAX_USE_PJIT", "true")
5959
jax.distributed.initialize()
6060
61+
idx = jax.process_index()
6162
global_devices = jax.devices()
62-
print(f"[Host {jax.process_index()}] Got {len(global_devices)} global devices")
63+
print(f"[Host {idx}] Got {len(global_devices)} global devices")
6364
mesh = Mesh(global_devices, ("x",))
6465
65-
print(f"[Host {jax.process_index()}] Allocating data...")
66-
print(f"[Host {jax.process_index()}] Defining sharding...")
66+
print(f"[Host {idx}] Allocating data...")
67+
print(f"[Host {idx}] Defining sharding...")
6768
size = 32768
6869
global_shape = (size, size)
69-
sharding = NamedSharding(mesh, jax.sharding.PartitionSpec("x", None))
70+
sharding = NamedSharding(
71+
mesh, jax.sharding.PartitionSpec("x", None)
72+
)
7073
71-
print(f"[Host {jax.process_index()}] Creating sharded data directly on devices...")
74+
print(f"[Host {idx}] Creating sharded data directly on devices...")
7275
7376
def ones_callback(index):
74-
resolved_indices = [s.indices(global_shape[i]) for i, s in enumerate(index)]
75-
local_shape = tuple(stop - start for start, stop, step in resolved_indices)
77+
resolved_indices = [
78+
s.indices(global_shape[i]) for i, s in enumerate(index)
79+
]
80+
local_shape = tuple(
81+
stop - start for start, stop, step in resolved_indices
82+
)
7683
7784
return jnp.ones(local_shape, dtype=jnp.float32)
7885
79-
x = jax.make_array_from_callback(global_shape, sharding, ones_callback)
80-
y = jax.make_array_from_callback(global_shape, sharding, ones_callback)
86+
x = jax.make_array_from_callback(
87+
global_shape, sharding, ones_callback)
88+
y = jax.make_array_from_callback(
89+
global_shape, sharding, ones_callback)
8190
82-
print(f"[Host {jax.process_index()}] Data on device")
91+
print(f"[Host {idx}] Data on device")
8392
84-
# ========= Define heavy workload =========
8593
@pjit
8694
def matmul_ultra_heavy(x, y):
8795
tmp1 = jnp.dot(x, y)
@@ -98,15 +106,15 @@ def matmul_ultra_heavy(x, y):
98106
print(f"[Host {jax.process_index()}] Starting benchmark...")
99107
100108
start = time.time()
101-
for i in range(1_000_000): # Remember to control loop time to control experiment time
109+
for i in range(1_000_000):
102110
result = matmul_ultra_heavy(x, y)
103111
result.block_until_ready()
104112
end = time.time()
105113
106114
if jax.process_index() == 0:
107115
print(f"Total time: {end - start:.2f} seconds (on full v6e-16)")
108-
'
109-
echo "sleep..."
116+
' &&
117+
echo "Workload finished, sleeping now..." &&
110118
sleep 10000
111119
"""
112120
),
@@ -489,17 +497,18 @@ def delete_one_random_pod(
489497
Defaults to "default".
490498
491499
Raises:
492-
AirflowFailException: If no running pods are found in the specified namespace.
500+
AirflowFailException: If no running pods are found in the specified
501+
namespace.
493502
"""
494503
running_pods = get_running_pods(node_pool=node_pool, namespace=namespace)
495504
if not running_pods:
496-
logging.error(f"No running pods found in namespace: {namespace}")
505+
logging.error("No running pods found in namespace: %s", namespace)
497506
raise AirflowFailException(
498507
f"No running pods found in namespace: {namespace}"
499508
)
500509

501510
target_pod = random.choice(running_pods)
502-
logging.info(f"Targeting pod for deletion: {target_pod}")
511+
logging.info("Targeting pod for deletion: %s", target_pod)
503512

504513
with tempfile.NamedTemporaryFile() as temp_config_file:
505514
env = os.environ.copy()
@@ -513,7 +522,7 @@ def delete_one_random_pod(
513522
])
514523

515524
subprocess.run_exec(cmd, env=env)
516-
logging.info(f"Successfully initiated deletion for pod: {target_pod}")
525+
logging.info("Successfully initiated deletion for pod: %s", target_pod)
517526

518527

519528
@task.sensor(poke_interval=30, timeout=900, mode="poke")
@@ -587,8 +596,7 @@ def wait_for_jobset_started(
587596
def wait_for_jobset_ttr_to_be_found(
588597
node_pool: node_pool_info, jobset_name: str
589598
) -> bool:
590-
"""
591-
Polls the jobset time_between_interruptions metric.
599+
"""Polls the jobset time_between_interruptions metric.
592600
593601
A sensor task which polls the jobset time_between_interruptions metric
594602
every 60 seconds for 60 minutes. 60 minutes is used here since this
@@ -598,8 +606,9 @@ def wait_for_jobset_ttr_to_be_found(
598606
impractical for the test to run longer.
599607
600608
Args:
601-
info(Info): An instance of the Info class that encapsulates
602-
the configuration and metadata of a GKE node pool and workload.
609+
node_pool: An instance of the Info class that encapsulates
610+
the configuration and metadata of a GKE node pool and workload.
611+
jobset_name: The name of the JobSet.
603612
"""
604613
now = datetime.datetime.now()
605614

0 commit comments

Comments (0)