Skip to content

Commit 4f59a66

Browse files
raytomatochanglan
authored and committed
[1/n] Refactor colocated python benchmark script
GitOrigin-RevId: 10bc8fb78e6180f3447db7bcf0177f7dfae950c6
1 parent 04d5fd7 commit 4f59a66

File tree

1 file changed

+80
-115
lines changed

1 file changed

+80
-115
lines changed

axlearn/cloud/gcp/examples/colocated_python_benchmark.py

Lines changed: 80 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from jax.experimental.array_serialization import tensorstore_impl
4040

4141
from axlearn.common import utils
42-
from axlearn.common.array_serialization import _async_deserialize
42+
from axlearn.common.array_serialization import GlobalAsyncCheckpointManager, _async_deserialize
4343
from axlearn.common.checkpointer import parse_step_from_dir, read_index_file
4444
from axlearn.common.utils import TensorSpec, infer_mesh_shape
4545

@@ -201,7 +201,7 @@ def create_checkpoint_spec_from_state(ckpt_dir: str, state_spec: dict):
201201
"""Create checkpoint spec following the pattern from TensorStoreStateStorage._get_spec."""
202202

203203
tensorstore_specs = []
204-
shapes = []
204+
global_shapes = []
205205
dtypes = []
206206
shardings = []
207207

@@ -227,11 +227,11 @@ def create_checkpoint_spec_from_state(ckpt_dir: str, state_spec: dict):
227227
sharding = jax.sharding.NamedSharding(mesh, partition_spec)
228228

229229
tensorstore_specs.append(tensorstore_spec)
230-
shapes.append(value.shape)
230+
global_shapes.append(value.shape)
231231
dtypes.append(dtype)
232232
shardings.append(sharding)
233233

234-
return tensorstore_specs, shardings, shapes, dtypes
234+
return tensorstore_specs, shardings, global_shapes, dtypes
235235

236236

237237
def cleanup_loaded_arrays(loaded_arrays: list) -> None:
@@ -266,122 +266,66 @@ def cleanup_loaded_arrays(loaded_arrays: list) -> None:
266266
print("Cleanup complete.")
267267

268268

269-
def _default_deserialize(
270-
shardings: Sequence[jax.sharding.NamedSharding],
269+
def load_model_default(
271270
tensorstore_specs: Sequence[Dict[str, Any]],
271+
shardings: Sequence[jax.sharding.NamedSharding],
272272
global_shapes: Sequence[tuple],
273273
dtypes: Sequence[jnp.dtype],
274274
):
275-
# concurrent_bytes = 1099511627776
276-
concurrent_bytes = 34359738368 * 6 # multiple of 32GB
277-
# Object should be created once per process.
278-
# pylint: disable=protected-access
279-
byte_limiter = tensorstore_impl._LimitInFlightBytes(concurrent_bytes)
280-
h2d_limiter = tensorstore_impl._LimitInFlightBytes(34359738368)
281-
thread_pool = ThreadPoolExecutor(1)
282-
multi_thread_pool = ThreadPoolExecutor(2)
283-
284-
future_arrays = jax.tree.map(
285-
functools.partial(
286-
_async_deserialize,
287-
byte_limiter=byte_limiter,
288-
h2d_limiter=h2d_limiter,
289-
single_thread_pool=thread_pool,
290-
multi_thread_pool=multi_thread_pool,
291-
),
292-
shardings,
293-
tensorstore_specs,
294-
global_shapes,
295-
dtypes,
275+
"""Load model using default method (direct to TPU)."""
276+
print("Preloading checkpoint to TPU HBM...")
277+
start_time = time.perf_counter()
278+
279+
manager = GlobalAsyncCheckpointManager()
280+
restored_values = manager.deserialize(
281+
shardings=shardings,
282+
tensorstore_specs=tensorstore_specs,
283+
global_shapes=global_shapes,
284+
dtypes=dtypes,
285+
concurrent_gb=192,
296286
)
297287

298-
async def gather_func():
299-
return await asyncio.gather(*future_arrays)
300-
301-
result = asyncio.run(gather_func())
302-
return result
303-
304-
305-
def load_model_default(ckpt_path: str):
306-
"""Main function to preload a model from GCS checkpoint."""
307-
step = parse_step_from_dir(ckpt_path)
308-
print(f"Starting model preload from: {ckpt_path} (step {step})")
309-
310-
if not ckpt_path.startswith("gs://"):
311-
raise ValueError(f"Only GCS paths (gs://) are supported, got: {ckpt_path}")
312-
313-
with create_mesh():
314-
print("Reading checkpoint structure...")
315-
state_spec = create_state_spec_from_checkpoint(ckpt_path)
316-
317-
print(f"Found {len(jax.tree_util.tree_leaves(state_spec))} tensors in checkpoint")
318-
319-
tensorstore_specs, shardings, shapes, dtypes = create_checkpoint_spec_from_state(
320-
ckpt_path, state_spec
321-
)
322-
323-
print("Preloading checkpoint to TPU memory...")
324-
start_time = time.perf_counter()
325-
326-
restored_values = _default_deserialize(
327-
shardings=shardings,
328-
tensorstore_specs=tensorstore_specs,
329-
global_shapes=shapes,
330-
dtypes=dtypes,
331-
)
332-
333-
preload_time = time.perf_counter() - start_time
334-
print(f"Preload completed in {preload_time:.2f} seconds")
335-
print(f"Preloaded {len(restored_values)} arrays")
336-
337-
return restored_values
288+
preload_time = time.perf_counter() - start_time
289+
print(f"Preload completed in {preload_time:.2f} seconds")
290+
print(f"Preloaded {len(restored_values)} arrays")
338291

292+
return restored_values
339293

340-
def load_model_colocated(ckpt_path: str):
341-
"""Main function to preload a model from GCS checkpoint."""
342-
step = parse_step_from_dir(ckpt_path)
343-
print(f"Starting model preload from: {ckpt_path} (step {step})")
344294

345-
if not ckpt_path.startswith("gs://"):
346-
raise ValueError(f"Only GCS paths (gs://) are supported, got: {ckpt_path}")
347-
348-
with create_mesh():
349-
print("Reading checkpoint structure...")
350-
state_spec = create_state_spec_from_checkpoint(ckpt_path)
351-
352-
print(f"Found {len(jax.tree_util.tree_leaves(state_spec))} tensors in checkpoint")
353-
354-
tensorstore_specs, shardings, shapes, dtypes = create_checkpoint_spec_from_state(
355-
ckpt_path, state_spec
356-
)
357-
358-
print("Preloading checkpoint to CPU memory...")
359-
start_time = time.perf_counter()
360-
361-
preloaded_values = _colocated_deserialize(
362-
shardings=shardings,
363-
tensorstore_specs=tensorstore_specs,
364-
global_shapes=shapes,
365-
dtypes=dtypes,
366-
)
367-
# for x in preloaded_values:
368-
# x.block_until_ready()
295+
def load_model_colocated(
296+
tensorstore_specs: Sequence[Dict[str, Any]],
297+
shardings: Sequence[jax.sharding.NamedSharding],
298+
global_shapes: Sequence[tuple],
299+
dtypes: Sequence[jnp.dtype],
300+
):
301+
"""Load model using colocated Python (CPU preload then transfer to TPU)."""
302+
print("Preloading checkpoint to CPU memory...")
303+
start_time = time.perf_counter()
304+
305+
preloaded_values = _colocated_deserialize(
306+
shardings=shardings,
307+
tensorstore_specs=tensorstore_specs,
308+
global_shapes=global_shapes,
309+
dtypes=dtypes,
310+
)
311+
# for x in preloaded_values:
312+
# x.block_until_ready()
369313

370-
preload_time = time.perf_counter() - start_time
371-
print(f"Preload completed in {preload_time:.2f} seconds")
372-
print(f"Preloaded {len(preloaded_values)} arrays")
314+
preload_time = time.perf_counter() - start_time
315+
print(f"Preload completed in {preload_time:.2f} seconds")
316+
print(f"Preloaded {len(preloaded_values)} arrays")
373317

374-
print("Transferring arrays to TPU...")
375-
start_time = time.perf_counter()
318+
print("Transferring arrays to TPU...")
319+
start_time = time.perf_counter()
376320

377-
restored_values = [jax.device_put(x, s) for x, s in zip(preloaded_values, shardings)]
378-
for x in restored_values:
379-
x.block_until_ready()
321+
restored_values = [jax.device_put(x, s) for x, s in zip(preloaded_values, shardings)]
322+
for x in restored_values:
323+
x.block_until_ready()
380324

381-
transfer_time = time.perf_counter() - start_time
382-
print(f"Transfer completed in {transfer_time:.2f} seconds")
325+
transfer_time = time.perf_counter() - start_time
326+
print(f"Transfer completed in {transfer_time:.2f} seconds")
383327

384-
return restored_values
328+
return restored_values
385329

386330

387331
def main():
@@ -422,23 +366,44 @@ def main():
422366
loader_fn = load_model_default
423367
print(f"--- Running {args.method} benchmark ---")
424368

369+
# Validate checkpoint path
370+
if not args.ckpt_path.startswith("gs://"):
371+
raise ValueError(f"Only GCS paths (gs://) are supported, got: {args.ckpt_path}")
425372
profile_dir = None
426373
if args.profile:
427374
# Create timestamped profile directory (minute-level granularity)
428375
timestamp = datetime.now().strftime("%Y%m%d%H%M")
429-
profile_dir = (
430-
f"{args.ckpt_path.split("/checkpoints")[0]}/profiles/{args.method}_{timestamp}/"
431-
)
376+
base_path = args.ckpt_path.split("/checkpoints")[0]
377+
profile_dir = f"{base_path}/profiles/{args.method}_{timestamp}/"
432378
print(f"Profiling enabled - results will be saved to {profile_dir}")
433379

380+
step = parse_step_from_dir(args.ckpt_path)
381+
print(f"Starting model preload from: {args.ckpt_path} (step {step})")
382+
383+
# Read checkpoint structure (doesn't need mesh)
384+
print("Reading checkpoint structure...")
385+
state_spec = create_state_spec_from_checkpoint(args.ckpt_path)
386+
print(f"Found {len(jax.tree_util.tree_leaves(state_spec))} tensors in checkpoint")
387+
434388
loaded_values = None
435389
try:
436-
with maybe_profile(args.profile, profile_dir):
437-
start_time = time.perf_counter()
438-
loaded_values = loader_fn(ckpt_path=args.ckpt_path)
439-
print(f"✅ Successfully loaded model from {args.ckpt_path}")
440-
print(f"Deserialize took {time.perf_counter() - start_time:.2f} seconds")
441-
print(f" Total parameters: {sum(x.size for x in loaded_values):,}")
390+
with create_mesh():
391+
# Create checkpoint specs (needs mesh)
392+
tensorstore_specs, shardings, global_shapes, dtypes = create_checkpoint_spec_from_state(
393+
args.ckpt_path, state_spec
394+
)
395+
396+
with maybe_profile(args.profile, profile_dir):
397+
start_time = time.perf_counter()
398+
loaded_values = loader_fn(
399+
tensorstore_specs=tensorstore_specs,
400+
shardings=shardings,
401+
global_shapes=global_shapes,
402+
dtypes=dtypes,
403+
)
404+
print(f"✅ Successfully loaded model from {args.ckpt_path}")
405+
print(f"Deserialize took {time.perf_counter() - start_time:.2f} seconds")
406+
print(f" Total parameters: {sum(x.size for x in loaded_values):,}")
442407
finally:
443408
# Always clean up, even if benchmark fails
444409
if loaded_values is not None:

0 commit comments

Comments
 (0)