3939from jax .experimental .array_serialization import tensorstore_impl
4040
4141from axlearn .common import utils
42- from axlearn .common .array_serialization import _async_deserialize
42+ from axlearn .common .array_serialization import GlobalAsyncCheckpointManager , _async_deserialize
4343from axlearn .common .checkpointer import parse_step_from_dir , read_index_file
4444from axlearn .common .utils import TensorSpec , infer_mesh_shape
4545
@@ -201,7 +201,7 @@ def create_checkpoint_spec_from_state(ckpt_dir: str, state_spec: dict):
201201 """Create checkpoint spec following the pattern from TensorStoreStateStorage._get_spec."""
202202
203203 tensorstore_specs = []
204- shapes = []
204+ global_shapes = []
205205 dtypes = []
206206 shardings = []
207207
@@ -227,11 +227,11 @@ def create_checkpoint_spec_from_state(ckpt_dir: str, state_spec: dict):
227227 sharding = jax .sharding .NamedSharding (mesh , partition_spec )
228228
229229 tensorstore_specs .append (tensorstore_spec )
230- shapes .append (value .shape )
230+ global_shapes .append (value .shape )
231231 dtypes .append (dtype )
232232 shardings .append (sharding )
233233
234- return tensorstore_specs , shardings , shapes , dtypes
234+ return tensorstore_specs , shardings , global_shapes , dtypes
235235
236236
237237def cleanup_loaded_arrays (loaded_arrays : list ) -> None :
@@ -266,122 +266,66 @@ def cleanup_loaded_arrays(loaded_arrays: list) -> None:
266266 print ("Cleanup complete." )
267267
268268
269- def _default_deserialize (
270- shardings : Sequence [jax .sharding .NamedSharding ],
269+ def load_model_default (
271270 tensorstore_specs : Sequence [Dict [str , Any ]],
271+ shardings : Sequence [jax .sharding .NamedSharding ],
272272 global_shapes : Sequence [tuple ],
273273 dtypes : Sequence [jnp .dtype ],
274274):
275- # concurrent_bytes = 1099511627776
276- concurrent_bytes = 34359738368 * 6 # multiple of 32GB
277- # Object should be created once per process.
278- # pylint: disable=protected-access
279- byte_limiter = tensorstore_impl ._LimitInFlightBytes (concurrent_bytes )
280- h2d_limiter = tensorstore_impl ._LimitInFlightBytes (34359738368 )
281- thread_pool = ThreadPoolExecutor (1 )
282- multi_thread_pool = ThreadPoolExecutor (2 )
283-
284- future_arrays = jax .tree .map (
285- functools .partial (
286- _async_deserialize ,
287- byte_limiter = byte_limiter ,
288- h2d_limiter = h2d_limiter ,
289- single_thread_pool = thread_pool ,
290- multi_thread_pool = multi_thread_pool ,
291- ),
292- shardings ,
293- tensorstore_specs ,
294- global_shapes ,
295- dtypes ,
275+ """Load model using default method (direct to TPU)."""
276+ print ("Preloading checkpoint to TPU HBM..." )
277+ start_time = time .perf_counter ()
278+
279+ manager = GlobalAsyncCheckpointManager ()
280+ restored_values = manager .deserialize (
281+ shardings = shardings ,
282+ tensorstore_specs = tensorstore_specs ,
283+ global_shapes = global_shapes ,
284+ dtypes = dtypes ,
285+ concurrent_gb = 192 ,
296286 )
297287
298- async def gather_func ():
299- return await asyncio .gather (* future_arrays )
300-
301- result = asyncio .run (gather_func ())
302- return result
303-
304-
305- def load_model_default (ckpt_path : str ):
306- """Main function to preload a model from GCS checkpoint."""
307- step = parse_step_from_dir (ckpt_path )
308- print (f"Starting model preload from: { ckpt_path } (step { step } )" )
309-
310- if not ckpt_path .startswith ("gs://" ):
311- raise ValueError (f"Only GCS paths (gs://) are supported, got: { ckpt_path } " )
312-
313- with create_mesh ():
314- print ("Reading checkpoint structure..." )
315- state_spec = create_state_spec_from_checkpoint (ckpt_path )
316-
317- print (f"Found { len (jax .tree_util .tree_leaves (state_spec ))} tensors in checkpoint" )
318-
319- tensorstore_specs , shardings , shapes , dtypes = create_checkpoint_spec_from_state (
320- ckpt_path , state_spec
321- )
322-
323- print ("Preloading checkpoint to TPU memory..." )
324- start_time = time .perf_counter ()
325-
326- restored_values = _default_deserialize (
327- shardings = shardings ,
328- tensorstore_specs = tensorstore_specs ,
329- global_shapes = shapes ,
330- dtypes = dtypes ,
331- )
332-
333- preload_time = time .perf_counter () - start_time
334- print (f"Preload completed in { preload_time :.2f} seconds" )
335- print (f"Preloaded { len (restored_values )} arrays" )
336-
337- return restored_values
288+ preload_time = time .perf_counter () - start_time
289+ print (f"Preload completed in { preload_time :.2f} seconds" )
290+ print (f"Preloaded { len (restored_values )} arrays" )
338291
292+ return restored_values
339293
340- def load_model_colocated (ckpt_path : str ):
341- """Main function to preload a model from GCS checkpoint."""
342- step = parse_step_from_dir (ckpt_path )
343- print (f"Starting model preload from: { ckpt_path } (step { step } )" )
344294
345- if not ckpt_path .startswith ("gs://" ):
346- raise ValueError (f"Only GCS paths (gs://) are supported, got: { ckpt_path } " )
347-
348- with create_mesh ():
349- print ("Reading checkpoint structure..." )
350- state_spec = create_state_spec_from_checkpoint (ckpt_path )
351-
352- print (f"Found { len (jax .tree_util .tree_leaves (state_spec ))} tensors in checkpoint" )
353-
354- tensorstore_specs , shardings , shapes , dtypes = create_checkpoint_spec_from_state (
355- ckpt_path , state_spec
356- )
357-
358- print ("Preloading checkpoint to CPU memory..." )
359- start_time = time .perf_counter ()
360-
361- preloaded_values = _colocated_deserialize (
362- shardings = shardings ,
363- tensorstore_specs = tensorstore_specs ,
364- global_shapes = shapes ,
365- dtypes = dtypes ,
366- )
367- # for x in preloaded_values:
368- # x.block_until_ready()
295+ def load_model_colocated (
296+ tensorstore_specs : Sequence [Dict [str , Any ]],
297+ shardings : Sequence [jax .sharding .NamedSharding ],
298+ global_shapes : Sequence [tuple ],
299+ dtypes : Sequence [jnp .dtype ],
300+ ):
301+ """Load model using colocated Python (CPU preload then transfer to TPU)."""
302+ print ("Preloading checkpoint to CPU memory..." )
303+ start_time = time .perf_counter ()
304+
305+ preloaded_values = _colocated_deserialize (
306+ shardings = shardings ,
307+ tensorstore_specs = tensorstore_specs ,
308+ global_shapes = global_shapes ,
309+ dtypes = dtypes ,
310+ )
311+ # for x in preloaded_values:
312+ # x.block_until_ready()
369313
370- preload_time = time .perf_counter () - start_time
371- print (f"Preload completed in { preload_time :.2f} seconds" )
372- print (f"Preloaded { len (preloaded_values )} arrays" )
314+ preload_time = time .perf_counter () - start_time
315+ print (f"Preload completed in { preload_time :.2f} seconds" )
316+ print (f"Preloaded { len (preloaded_values )} arrays" )
373317
374- print ("Transferring arrays to TPU..." )
375- start_time = time .perf_counter ()
318+ print ("Transferring arrays to TPU..." )
319+ start_time = time .perf_counter ()
376320
377- restored_values = [jax .device_put (x , s ) for x , s in zip (preloaded_values , shardings )]
378- for x in restored_values :
379- x .block_until_ready ()
321+ restored_values = [jax .device_put (x , s ) for x , s in zip (preloaded_values , shardings )]
322+ for x in restored_values :
323+ x .block_until_ready ()
380324
381- transfer_time = time .perf_counter () - start_time
382- print (f"Transfer completed in { transfer_time :.2f} seconds" )
325+ transfer_time = time .perf_counter () - start_time
326+ print (f"Transfer completed in { transfer_time :.2f} seconds" )
383327
384- return restored_values
328+ return restored_values
385329
386330
387331def main ():
@@ -422,23 +366,44 @@ def main():
422366 loader_fn = load_model_default
423367 print (f"--- Running { args .method } benchmark ---" )
424368
369+ # Validate checkpoint path
370+ if not args .ckpt_path .startswith ("gs://" ):
371+ raise ValueError (f"Only GCS paths (gs://) are supported, got: { args .ckpt_path } " )
425372 profile_dir = None
426373 if args .profile :
427374 # Create timestamped profile directory (minute-level granularity)
428375 timestamp = datetime .now ().strftime ("%Y%m%d%H%M" )
429- profile_dir = (
430- f"{ args .ckpt_path .split ("/checkpoints" )[0 ]} /profiles/{ args .method } _{ timestamp } /"
431- )
376+ base_path = args .ckpt_path .split ("/checkpoints" )[0 ]
377+ profile_dir = f"{ base_path } /profiles/{ args .method } _{ timestamp } /"
432378 print (f"Profiling enabled - results will be saved to { profile_dir } " )
433379
380+ step = parse_step_from_dir (args .ckpt_path )
381+ print (f"Starting model preload from: { args .ckpt_path } (step { step } )" )
382+
383+ # Read checkpoint structure (doesn't need mesh)
384+ print ("Reading checkpoint structure..." )
385+ state_spec = create_state_spec_from_checkpoint (args .ckpt_path )
386+ print (f"Found { len (jax .tree_util .tree_leaves (state_spec ))} tensors in checkpoint" )
387+
434388 loaded_values = None
435389 try :
436- with maybe_profile (args .profile , profile_dir ):
437- start_time = time .perf_counter ()
438- loaded_values = loader_fn (ckpt_path = args .ckpt_path )
439- print (f"✅ Successfully loaded model from { args .ckpt_path } " )
440- print (f"Deserialize took { time .perf_counter () - start_time :.2f} seconds" )
441- print (f" Total parameters: { sum (x .size for x in loaded_values ):,} " )
390+ with create_mesh ():
391+ # Create checkpoint specs (needs mesh)
392+ tensorstore_specs , shardings , global_shapes , dtypes = create_checkpoint_spec_from_state (
393+ args .ckpt_path , state_spec
394+ )
395+
396+ with maybe_profile (args .profile , profile_dir ):
397+ start_time = time .perf_counter ()
398+ loaded_values = loader_fn (
399+ tensorstore_specs = tensorstore_specs ,
400+ shardings = shardings ,
401+ global_shapes = global_shapes ,
402+ dtypes = dtypes ,
403+ )
404+ print (f"✅ Successfully loaded model from { args .ckpt_path } " )
405+ print (f"Deserialize took { time .perf_counter () - start_time :.2f} seconds" )
406+ print (f" Total parameters: { sum (x .size for x in loaded_values ):,} " )
442407 finally :
443408 # Always clean up, even if benchmark fails
444409 if loaded_values is not None :
0 commit comments