The implications of check_vma in the shard_map API #31449
-
I think the documentation of check_vma could be clearer. Say I have an MLP with two linear layers. Here is the config:

import numpy as np
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# num_devices = 8
devices = np.array(jax.devices())
mesh_axis_names = ("feats",)
mesh = Mesh(devices, axis_names=mesh_axis_names)
# input shape: [batch_size, features], labels shape: [batch_size,]
# params = MLP(
#     fc1=Linear(
#         in_features=784,
#         out_features=128,
#         weight=float32[784,128],
#         bias=float32[128],
#         use_bias=True,
#     ),
#     fc2=Linear(
#         in_features=128,
#         out_features=10,
#         weight=float32[128,10],
#         bias=None,
#         use_bias=False,
#     )
# )
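# NOTE (assumption, not shown in the question): param_specs is referenced in
# in_specs below but never defined here. For this TP layout it would presumably
# be a pytree of PartitionSpecs mirroring params, e.g.:
#   fc1.weight -> P('feats', None)   # shard in_features: (784, 128) becomes (98, 128) per device
#   fc1.bias   -> P('feats')         # shard features:    (128,)     becomes (16,)   per device
#   fc2.weight -> P(None, None)      # replicated on every device
#   fc2.bias   -> None               # no bias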
# forward function
@partial(jax.shard_map, mesh=mesh,
         in_specs=(P(None, 'feats'), param_specs),
         out_specs=P(),
         check_vma=False)
def forward_tp(inputs, params):
    # per-device inputs shard: (256, 98), i.e. the global (256, 784) batch split 784 / 8 along 'feats'
    out = jnp.dot(inputs, params.fc1.weight)
    # our first layer is sharded for TP
    out = jax.lax.psum_scatter(out, 'feats', scatter_dimension=1, tiled=True)
    if params.fc1.bias is not None:
        out = out + params.fc1.bias
    # second layer is replicated, so we need to gather all features
    out = jax.lax.all_gather(out, 'feats', axis=1, tiled=True)
    out = jnp.dot(out, params.fc2.weight)
    return out

Without check_vma=False, this raises an error. What are the implications of disabling check_vma here?
-
What is the error? If it is what I am guessing, you will need to use jax.lax.all_gather_invariant for out_specs to be P(). But usually, no, you shouldn't need to disable check_vma.
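A minimal sketch of that suggestion, reusing the mesh, the partial import, and the (assumed) param_specs from the question above, and assuming all_gather_invariant accepts the same axis/tiled arguments as all_gather:

@partial(jax.shard_map, mesh=mesh,
         in_specs=(P(None, 'feats'), param_specs),
         out_specs=P())  # check_vma left at its default (True)
def forward_tp(inputs, params):
    out = jnp.dot(inputs, params.fc1.weight)
    out = jax.lax.psum_scatter(out, 'feats', scatter_dimension=1, tiled=True)
    if params.fc1.bias is not None:
        out = out + params.fc1.bias
    # all_gather_invariant marks the gathered value as invariant across 'feats',
    # which is what out_specs=P() requires when check_vma is on
    out = jax.lax.all_gather_invariant(out, 'feats', axis=1, tiled=True)
    out = jnp.dot(out, params.fc2.weight)
    return out

Here the per-device output is (256, 10), and out_specs=P() treats it as a single replicated global (256, 10) array.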
-
Use jax.lax.all_gather_invariant if you want out_specs to be P().
If you want to keep using all_gather, then your out_specs should be P(None, 'feats').
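For the second option, only the out_specs changes relative to the sketch above (again reusing the question's mesh and assumed param_specs; the name forward_tp_gathered is just for illustration):

@partial(jax.shard_map, mesh=mesh,
         in_specs=(P(None, 'feats'), param_specs),
         out_specs=P(None, 'feats'))  # must mention 'feats': the all_gather result is still typed as varying over it
def forward_tp_gathered(inputs, params):
    out = jnp.dot(inputs, params.fc1.weight)
    out = jax.lax.psum_scatter(out, 'feats', scatter_dimension=1, tiled=True)
    if params.fc1.bias is not None:
        out = out + params.fc1.bias
    out = jax.lax.all_gather(out, 'feats', axis=1, tiled=True)
    out = jnp.dot(out, params.fc2.weight)
    return out

Note that with out_specs=P(None, 'feats') the eight per-device (256, 10) outputs are concatenated along axis 1, so the global result should be a (256, 80) array made of identical copies; the all_gather_invariant version with out_specs=P() is usually the cleaner choice for this layout.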