Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def generate_second_node_pool_name(
container_name="jax-tpu-worker",
image=(
"asia-northeast1-docker.pkg.dev/cienet-cmcs/yuna-docker/"
"tpu-info:v0.5.1"
"tpu-info:v0.8.1"
),
tpu_cores_per_pod=4,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def validate_monitoring_sdk(info: node_pool.Info, pod_name: str) -> None:
tpu_topology="4x4",
container_name="jax-tpu-worker",
image="asia-northeast1-docker.pkg.dev/cienet-cmcs/"
"yuna-docker/tpu-info:v0.5.1",
"yuna-docker/tpu-info:v0.8.1",
tpu_cores_per_pod=4,
)

Expand Down
65 changes: 39 additions & 26 deletions dags/tpu_observability/utils/jobset_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@

from airflow.decorators import task
from airflow.exceptions import AirflowFailException
from google.cloud.monitoring_v3 import types
import kubernetes

from dags.tpu_observability.utils.node_pool_util import Info as node_pool_info
from dags.tpu_observability.utils import subprocess_util as subprocess
from dags.tpu_observability.utils.gcp_util import query_time_series
from dags.tpu_observability.utils.node_pool_util import Info as node_pool_info
from dags.tpu_observability.utils.time_util import TimeUtil
from google.cloud.monitoring_v3 import types
import kubernetes
from xlml.utils import gke


Expand All @@ -58,25 +58,38 @@ class Workload:
os.environ.setdefault("JAX_USE_PJIT", "true")
jax.distributed.initialize()

idx = jax.process_index()
global_devices = jax.devices()
print(
f"[Host {jax.process_index()}] "
f"Got {len(global_devices)} global devices"
)
print(f"[Host {idx}] Got {len(global_devices)} global devices")
mesh = Mesh(global_devices, ("x",))

print(f"[Host {jax.process_index()}] Allocating data...")
size = 32832
x_global = jnp.ones((size, size), dtype=jnp.float32)
y_global = jnp.ones((size, size), dtype=jnp.float32)
print(f"[Host {idx}] Allocating data...")
print(f"[Host {idx}] Defining sharding...")
size = 32768
global_shape = (size, size)
sharding = NamedSharding(
mesh, jax.sharding.PartitionSpec("x", None)
)

print(f"[Host {idx}] Creating sharded data directly on devices...")

def ones_callback(index):
resolved_indices = [
s.indices(global_shape[i]) for i, s in enumerate(index)
]
local_shape = tuple(
stop - start for start, stop, step in resolved_indices
)

return jnp.ones(local_shape, dtype=jnp.float32)

print(f"[Host {jax.process_index()}] Sharding data...")
sharding = NamedSharding(mesh, jax.sharding.PartitionSpec("x", None))
x = jax.device_put(x_global, sharding)
y = jax.device_put(y_global, sharding)
print(f"[Host {jax.process_index()}] Data on device")
x = jax.make_array_from_callback(
global_shape, sharding, ones_callback)
y = jax.make_array_from_callback(
global_shape, sharding, ones_callback)

print(f"[Host {idx}] Data on device")

# ========= Define heavy workload =========
@pjit
def matmul_ultra_heavy(x, y):
tmp1 = jnp.dot(x, y)
Expand All @@ -93,7 +106,6 @@ def matmul_ultra_heavy(x, y):
print(f"[Host {jax.process_index()}] Starting benchmark...")

start = time.time()
# Remember to control loop time to control experiment time
for i in range(1_000_000):
result = matmul_ultra_heavy(x, y)
result.block_until_ready()
Expand Down Expand Up @@ -485,17 +497,18 @@ def delete_one_random_pod(
Defaults to "default".

Raises:
AirflowFailException: If no running pods are found in the specified namespace.
AirflowFailException: If no running pods are found in the specified
namespace.
"""
running_pods = get_running_pods(node_pool=node_pool, namespace=namespace)
if not running_pods:
logging.error(f"No running pods found in namespace: {namespace}")
logging.error("No running pods found in namespace: %s", namespace)
raise AirflowFailException(
f"No running pods found in namespace: {namespace}"
)

target_pod = random.choice(running_pods)
logging.info(f"Targeting pod for deletion: {target_pod}")
logging.info("Targeting pod for deletion: %s", target_pod)

with tempfile.NamedTemporaryFile() as temp_config_file:
env = os.environ.copy()
Expand All @@ -509,7 +522,7 @@ def delete_one_random_pod(
])

subprocess.run_exec(cmd, env=env)
logging.info(f"Successfully initiated deletion for pod: {target_pod}")
logging.info("Successfully initiated deletion for pod: %s", target_pod)


@task.sensor(poke_interval=30, timeout=900, mode="poke")
Expand Down Expand Up @@ -583,8 +596,7 @@ def wait_for_jobset_started(
def wait_for_jobset_ttr_to_be_found(
node_pool: node_pool_info, jobset_name: str
) -> bool:
"""
Polls the jobset time_between_interruptions metric.
"""Polls the jobset time_between_interruptions metric.

A sensor task which polls the jobset time_between_interruptions metric
every 60 seconds for 60 minutes. 60 minutes is used here since this
Expand All @@ -594,8 +606,9 @@ def wait_for_jobset_ttr_to_be_found(
impractical for the test to run longer.

Args:
info(Info): An instance of the Info class that encapsulates
the configuration and metadata of a GKE node pool and workload.
node_pool: An instance of the Info class that encapsulates
the configuration and metadata of a GKE node pool and workload.
jobset_name: The name of the JobSet.
"""
now = datetime.datetime.now()

Expand Down
Loading