Skip to content

Commit 8fdf3c2

Browse files
Do unboxing only on concrete states and not abstract_state (#242)
1 parent c24a86b commit 8fdf3c2

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

src/maxdiffusion/max_utils.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -359,12 +359,11 @@ def get_abstract_state(model, tx, config, mesh, weights_init_fn, training=True):
359 359
state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, config.logical_axis_rules)
360 360

361 361
abstract_sharded_state = jax.jit(init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings).eval_shape()
362-
unboxed_sharded_abstract_state = unbox_logicallypartioned_trainstate(abstract_sharded_state)
363 362

364 363
# Initialization
365 364
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
366 365
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
367-
return unboxed_sharded_abstract_state, state_mesh_annotations, state_mesh_shardings
366+
return abstract_sharded_state, state_mesh_annotations, state_mesh_shardings
368 367

369 368

370 369
def setup_initial_state(

0 commit comments

Comments
 (0)