Actually, we need to do some more work to make this possible, but this is roughly how you can do it (note: this code currently raises an error, which I'll fix):

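import jax
import numpy as np
from jax.sharding import AbstractMesh, AxisType, NamedSharding, PartitionSpec as P
from jax.sharding import use_abstract_mesh  # export location may vary across JAX versions

mesh = jax.make_mesh((4, 2), ('x', 'y'), axis_types=(AxisType.Explicit, AxisType.Explicit))

a = jax.device_put(np.arange(8).reshape(4, 2), NamedSharding(mesh, P('x', 'y')))

b = jax.device_put(np.arange(8).reshape(4, 2), NamedSharding(mesh, P('x', 'y')))

@jax.jit
def f(a, b):
    c = a * b
    print(f'{jax.typeof(c)=}')  # shows c's abstract type, including its sharding
    # Switch to an abstract mesh with the same shape but renamed axes ('x1', 'y1').
    with use_abstract_mesh(AbstractMesh((4, 2), ('x1', 'y1'), axis_types=(AxisType.Explicit, AxisType.Explicit))):
        c = mesh_ca…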
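Inside `f`, the `print` is evidently meant to label the abstract type of `c`. In a Python f-string, the self-documenting `=` specifier (Python 3.8+) only works inside the braces; placed after the closing brace it is just a literal character appended to the value. A plain-Python illustration, with a tuple standing in for `jax.typeof(c)` (the variable names here are illustrative only):

```python
shape = (4, 2)  # stand-in for jax.typeof(c)

# '=' outside the braces: the value is formatted, then a literal '=' is appended.
literal = f'{shape}='

# '=' inside the braces: the expression text, '=', then the value's repr.
debug = f'{shape=}'

print(literal)  # prints "(4, 2)="
print(debug)    # prints "shape=(4, 2)"
```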

Replies: 1 comment, 11 replies

@yashk2810
Answer selected by jfc4050
@jfc4050

@yashk2810

@jfc4050

@yashk2810

Category: Ideas
Labels: none yet
2 participants