Skip to content

Commit 2053787

Browse files
committed
fix: update workload to be compatible with latest JAX/libtpu
Due to a recent JAX upgrade, previous workload versions are no longer executable in the latest environment. Changes: - Updated the base image to include the latest JAX, tpu-info, and libtpu. - Refactored the workload code to ensure compatibility with the updated JAX API.
1 parent 0f0e4e2 commit 2053787

File tree

3 files changed

+21
-17
lines changed

3 files changed

+21
-17
lines changed

dags/tpu_observability/tpu_info_format_validation_dags.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def generate_second_node_pool_name(
353353
container_name="jax-tpu-worker",
354354
image=(
355355
"asia-northeast1-docker.pkg.dev/cienet-cmcs/yuna-docker/"
356-
"tpu-info:v0.5.1"
356+
"tpu-info:v0.8.1"
357357
),
358358
tpu_cores_per_pod=4,
359359
)

dags/tpu_observability/tpu_sdk_monitoring_validation_dag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def validate_monitoring_sdk(info: node_pool.Info, pod_name: str) -> None:
131131
tpu_topology="4x4",
132132
container_name="jax-tpu-worker",
133133
image="asia-northeast1-docker.pkg.dev/cienet-cmcs/"
134-
"yuna-docker/tpu-info:v0.5.1",
134+
"yuna-docker/tpu-info:v0.8.1",
135135
tpu_cores_per_pod=4,
136136
)
137137

dags/tpu_observability/utils/jobset_util.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,21 +59,26 @@ class Workload:
5959
jax.distributed.initialize()
6060
6161
global_devices = jax.devices()
62-
print(
63-
f"[Host {jax.process_index()}] "
64-
f"Got {len(global_devices)} global devices"
65-
)
62+
print(f"[Host {jax.process_index()}] Got {len(global_devices)} global devices")
6663
mesh = Mesh(global_devices, ("x",))
6764
6865
print(f"[Host {jax.process_index()}] Allocating data...")
69-
size = 32832
70-
x_global = jnp.ones((size, size), dtype=jnp.float32)
71-
y_global = jnp.ones((size, size), dtype=jnp.float32)
72-
73-
print(f"[Host {jax.process_index()}] Sharding data...")
66+
print(f"[Host {jax.process_index()}] Defining sharding...")
67+
size = 32768
68+
global_shape = (size, size)
7469
sharding = NamedSharding(mesh, jax.sharding.PartitionSpec("x", None))
75-
x = jax.device_put(x_global, sharding)
76-
y = jax.device_put(y_global, sharding)
70+
71+
print(f"[Host {jax.process_index()}] Creating sharded data directly on devices...")
72+
73+
def ones_callback(index):
74+
resolved_indices = [s.indices(global_shape[i]) for i, s in enumerate(index)]
75+
local_shape = tuple(stop - start for start, stop, step in resolved_indices)
76+
77+
return jnp.ones(local_shape, dtype=jnp.float32)
78+
79+
x = jax.make_array_from_callback(global_shape, sharding, ones_callback)
80+
y = jax.make_array_from_callback(global_shape, sharding, ones_callback)
81+
7782
print(f"[Host {jax.process_index()}] Data on device")
7883
7984
# ========= Define heavy workload =========
@@ -93,16 +98,15 @@ def matmul_ultra_heavy(x, y):
9398
print(f"[Host {jax.process_index()}] Starting benchmark...")
9499
95100
start = time.time()
96-
# Remember to control loop time to control experiment time
97-
for i in range(1_000_000):
101+
for i in range(1_000_000): # Remember to control loop time to control experiment time
98102
result = matmul_ultra_heavy(x, y)
99103
result.block_until_ready()
100104
end = time.time()
101105
102106
if jax.process_index() == 0:
103107
print(f"Total time: {end - start:.2f} seconds (on full v6e-16)")
104-
' &&
105-
echo "Workload finished, sleeping now..." &&
108+
'
109+
echo "sleep..."
106110
sleep 10000
107111
"""
108112
),

0 commit comments

Comments (0)