@@ -231,16 +231,20 @@ def _shard_mutable_array(xs, shardings, layouts, copy_semantics):
231231def batched_device_put (aval : core .ShapedArray ,
232232 sharding : JSharding , xs : Sequence [Any ],
233233 devices : Sequence [jax .Device ], committed : bool = True ):
234- from jax ._src import array
235-
236- bufs = [x for x , d in safe_zip (xs , devices )
237- if (isinstance (x , array .ArrayImpl ) and
238- dispatch .is_single_device_sharding (x .sharding ) and
239- x .devices () == {d })]
240- if len (bufs ) == len (xs ):
241- return array .ArrayImpl (
242- aval , sharding , bufs , committed = committed , _skip_checks = True )
243- return xc .batched_device_put (aval , sharding , xs , list (devices ), committed )
234+ util .test_event ("batched_device_put_start" )
235+ try :
236+ from jax ._src import array
237+
238+ bufs = [x for x , d in safe_zip (xs , devices )
239+ if (isinstance (x , array .ArrayImpl ) and
240+ dispatch .is_single_device_sharding (x .sharding ) and
241+ x .devices () == {d })]
242+ if len (bufs ) == len (xs ):
243+ return array .ArrayImpl (
244+ aval , sharding , bufs , committed = committed , _skip_checks = True )
245+ return xc .batched_device_put (aval , sharding , xs , list (devices ), committed )
246+ finally :
247+ util .test_event ("batched_device_put_end" )
244248
245249def _shard_aval (size , axis : int , aval ):
246250 try :
@@ -2850,6 +2854,7 @@ def from_hlo(name: str,
28502854 mesh = i .mesh
28512855 break
28522856
2857+ util .test_event ("pxla_cached_compilation" )
28532858 xla_executable = _cached_compilation (
28542859 hlo , name , mesh , spmd_lowering ,
28552860 tuple_args , auto_spmd_lowering , allow_prop_to_inputs ,
0 commit comments