
Commit ae705fe

yashk2810 authored and Google-ML-Automation committed
[sharding_in_types] Add support for svd_p
PiperOrigin-RevId: 720409750
1 parent 24987a9 commit ae705fe

File tree

5 files changed: +156 −36 lines changed


jax/_src/core.py

Lines changed: 23 additions & 10 deletions
@@ -1731,14 +1731,20 @@ def _invalid_shape_error(shape: Shape, context: str=""):
 
   return TypeError(msg)
 
+def _make_lengths_same(sharding, ndim):
+  if ndim > len(sharding.spec):
+    return sharding.with_spec(sharding.spec._normalized_spec(ndim))
+  if ndim < len(sharding.spec):
+    return sharding.with_spec(sharding.spec[:ndim])
+  assert False, "unreachable"
+
+
 # TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
 # Collective too.
 def modify_spec_for_auto_manual(spec, mesh) -> P:
-  if all(s is None for s in spec):
-    return spec
   new_spec = []  # type: ignore
   for s in spec:
-    if s is None:
+    if not s:
       new_spec.append(s)
     else:
       temp_s = s[0] if isinstance(s, tuple) else s
@@ -1748,22 +1754,29 @@ def modify_spec_for_auto_manual(spec, mesh) -> P:
                      else s)
   return P(*new_spec)
 
-def _maybe_modify_sharding(sharding):
+def _maybe_modify_sharding(sharding, ndim):
   if sharding.mesh._are_all_axes_explicit:
-    return sharding
-  new_spec = modify_spec_for_auto_manual(sharding.spec, sharding.mesh)
-  return sharding.with_spec(new_spec)
+    out = sharding
+  elif all(s is None for s in sharding.spec):
+    out = sharding
+  else:
+    out = sharding.with_spec(modify_spec_for_auto_manual(
+        sharding.spec, sharding.mesh))
+  if (len(out.spec) != ndim and
+      (out.mesh._are_all_axes_auto or out.mesh._are_all_axes_manual)):
+    out = _make_lengths_same(out, ndim)
+  return out
 
 
 def get_sharding(sharding, ndim):
   from jax._src.sharding_impls import NamedSharding  # type: ignore
 
   if sharding is not None:
-    if len(sharding.spec) != ndim:
+    out_s = _maybe_modify_sharding(sharding, ndim)
+    if len(out_s.spec) != ndim:
       raise ValueError(
           "Length of sharding.spec must be equal to aval's ndim. Got"
-          f" sharding.spec {sharding.spec} and aval.ndim {ndim}")
-    out_s = _maybe_modify_sharding(sharding)
+          f" sharding.spec {out_s.spec} and aval.ndim {ndim}")
   else:
     context_mesh = mesh_lib.get_abstract_mesh()
     if not context_mesh:
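The new _make_lengths_same helper is what keeps the length check in get_sharding from tripping when a fully-Auto or fully-Manual mesh leaves a spec shorter or longer than the aval's rank: the spec is padded with unsharded entries or truncated instead. Below is a minimal standalone sketch of that normalization, with plain tuples standing in for PartitionSpec (an assumption for illustration; the real helper goes through spec._normalized_spec and sharding.with_spec):

# Hypothetical, simplified analogue of _make_lengths_same above.
# A spec is modeled as a tuple of axis names / None entries.
def make_lengths_same(spec: tuple, ndim: int) -> tuple:
  if ndim > len(spec):
    # Pad with None: the extra dimensions are treated as unsharded.
    return spec + (None,) * (ndim - len(spec))
  if ndim < len(spec):
    # Drop trailing entries so the spec matches the aval's rank.
    return spec[:ndim]
  return spec  # lengths already agree

assert make_lengths_same((None,), 3) == (None, None, None)
assert make_lengths_same(('x', None, None), 2) == ('x', None)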

jax/_src/lax/lax.py

Lines changed: 16 additions & 3 deletions
@@ -5165,18 +5165,28 @@ def _rev_shape_rule(operand, *, dimensions):
     raise TypeError(msg.format(dimensions, operand.ndim))
   return operand.shape
 
+def _rev_sharding_rule(operand, *, dimensions):
+  # TODO(yashkatariya): Will lead to data movement. Maybe just error out and
+  # require the operand to be unsharded?
+  return operand.sharding
+
 def _rev_batch_rule(batched_args, batch_dims, *, dimensions):
   operand, = batched_args
   bdim, = batch_dims
   new_dimensions = [i + 1 if i >= bdim else i for i in dimensions]
   return rev(operand, new_dimensions), bdim
 
-rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev')
+rev_p = standard_primitive(_rev_shape_rule, _input_dtype, 'rev',
+                           sharding_rule=_rev_sharding_rule)
 ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
 batching.primitive_batchers[rev_p] = _rev_batch_rule
 
 def _rev_lower(ctx, x, *, dimensions):
-  return [hlo.reverse(x, mlir.dense_int_array(dimensions))]
+  aval_out, = ctx.avals_out
+  out = hlo.reverse(x, mlir.dense_int_array(dimensions))
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
+  return [out]
 mlir.register_lowering(rev_p, _rev_lower)
 
 
@@ -5932,7 +5942,10 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
       mlir.flatten_ir_values(operands),
       dimension=mlir.i64_attr(dimension),
       is_stable=ir.BoolAttr.get(is_stable))
-  scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in]
+  scalar_s = (lambda a: a.sharding.with_spec(P())
+              if config.sharding_in_types.value else lambda _: None)
+  scalar_avals = [aval.update(shape=(), sharding=scalar_s(aval))
+                  for aval in ctx.avals_in]
   scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals)
   comparator = sort.comparator.blocks.append(
       *util.flatten(zip(scalar_types, scalar_types)))
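Both lax.py changes follow the commit's general pattern: a sharding rule picks the output spec and the lowering re-attaches it when the sharding-in-types flag is on. In _sort_lower the comparator operates on scalars, so their avals get an empty PartitionSpec, i.e. fully replicated. A short illustration of that spec choice using the public jax.sharding.PartitionSpec (illustrative only, not the lowering code itself):

from jax.sharding import PartitionSpec as P

# A rank-0 (scalar) value has no dimensions left to shard, so its spec is
# the empty PartitionSpec, which means "replicated across the mesh".
scalar_spec = P()
assert len(scalar_spec) == 0

# A 2-D operand sharded over mesh axis 'x' along its rows, for contrast:
matrix_spec = P('x', None)
assert len(matrix_spec) == 2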

jax/_src/lax/linalg.py

Lines changed: 102 additions & 23 deletions
@@ -40,6 +40,7 @@
 from jax._src.lax import control_flow
 from jax._src.lax import eigh as lax_eigh
 from jax._src.lax import lax as lax_internal
+from jax._src.partition_spec import PartitionSpec as P
 from jax._src.lax import svd as lax_svd
 from jax._src.lax.lax import (
     standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
@@ -960,9 +961,20 @@ def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
 
     batch_dims = operand.shape[:-2]
     n = operand.shape[-1]
+    if config.sharding_in_types.value:
+      batch_s = operand.sharding.spec[:-2]
+      ns = operand.sharding.spec[-1]
+      if ns is not None:
+        raise ValueError(f'n should be unsharded. Got n: {ns}'
+                         ' specs. Try marking their specs as None.')
+      w_s = operand.sharding.with_spec(P(*batch_s + (ns,)))
+      v_s = operand.sharding.with_spec(P(*batch_s + (ns, ns)))
+    else:
+      w_s, v_s = None, None
     w = operand.update(shape=batch_dims + (n,),
-                       dtype=lax_internal._complex_basetype(operand.dtype))
-    v = operand.update(shape=batch_dims + (n, n))
+                       dtype=lax_internal._complex_basetype(operand.dtype),
+                       sharding=w_s)
+    v = operand.update(shape=batch_dims + (n, n), sharding=v_s)
   else:
     w, v = operand, operand
   return w, v
@@ -1029,16 +1041,23 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index):
 
     batch_dims = operand.shape[:-2]
     n = operand.shape[-1]
-    d = (
-        n
-        if subset_by_index is None
-        else subset_by_index[1] - subset_by_index[0]
-    )
-    v = operand.update(shape=batch_dims + (n, d))
+    d = (n if subset_by_index is None else
+         subset_by_index[1] - subset_by_index[0])
+    if config.sharding_in_types.value:
+      batch_s = operand.sharding.spec[:-2]
+      ns, ds = operand.sharding.spec[-1], None
+      if ns is not None:
+        raise ValueError(f'n should be unsharded. Got n: {ns} specs. Try '
+                         'marking their specs as None.')
+      v_s = operand.sharding.with_spec(P(*batch_s + (ns, ds)))
+      w_s = operand.sharding.with_spec(P(*batch_s + (ds,)))
+    else:
+      v_s, w_s = None, None
+    v = operand.update(shape=batch_dims + (n, d), sharding=v_s)
     w = operand.update(
         shape=batch_dims + (d,),
         dtype=lax_internal._complex_basetype(operand.dtype),
-    )
+        sharding=w_s)
   else:
     v, w = operand, operand
   return v, w
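Both eigendecomposition rules above use the same convention: batch dimensions inherit the operand's spec, while the trailing dimensions being factorized must be unsharded (None) or the rule raises. A standalone sketch of that split-and-check, with a hypothetical helper name and plain tuples in place of PartitionSpec:

# Hypothetical validation mirroring the pattern above: the last `num_core_dims`
# dimensions of the operand (the ones the decomposition factorizes) must be
# unsharded; the leading batch dimensions pass through unchanged.
def split_batch_and_core_spec(spec: tuple, num_core_dims: int):
  batch_spec, core_spec = spec[:-num_core_dims], spec[-num_core_dims:]
  if any(s is not None for s in core_spec):
    raise ValueError(
        f"core dimensions should be unsharded, got {core_spec}; "
        "try marking their specs as None")
  return batch_spec, core_spec

# A batched eigh-style operand sharded only over its batch dimension:
batch, core = split_batch_and_core_spec(('x', None, None), num_core_dims=2)
assert batch == ('x',) and core == (None, None)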
@@ -1249,6 +1268,24 @@ def _triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
     raise TypeError(msg.format(a.shape, b.shape))
   return b.shape
 
+def _triangular_solve_sharding_rule(a, b, *, left_side=False, **unused_kwargs):
+  a_spec, b_spec = a.sharding.spec, b.sharding.spec
+  if a_spec[-1] != a_spec[-2]:
+    raise TypeError(
+        "triangular_solve requires the last two dimensions of a to be equal "
+        f"in sharding, got a_spec of {a_spec}.")
+  if a_spec[:-2] != b_spec[:-2]:
+    raise TypeError(
+        "triangular_solve requires both arguments to have the same number "
+        f"of dimensions and equal batch shardings, got {a_spec} and {b_spec}.")
+  common_dim = -2 if left_side else -1
+  if a_spec[-1] != b_spec[common_dim]:
+    raise TypeError(
+        "Incompatible shardings for arguments to triangular_solve:"
+        f" {a_spec} and {b_spec}.")
+  return b.sharding
+
+
 def _triangular_solve_jvp_rule_a(
     g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
     unit_diagonal):
@@ -1328,7 +1365,7 @@ def _triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
 
 triangular_solve_p = standard_primitive(
     _triangular_solve_shape_rule, _triangular_solve_dtype_rule,
-    'triangular_solve')
+    'triangular_solve', sharding_rule=_triangular_solve_sharding_rule)
 ad.defjvp2(triangular_solve_p,
            _triangular_solve_jvp_rule_a,
            lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
@@ -1346,10 +1383,13 @@ def _triangular_solve_lowering(
     transpose = "NO_TRANSPOSE"
   else:
     transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
-  return [hlo.triangular_solve(
-      a, b, ir.BoolAttr.get(left_side),
-      ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
-      hlo.TransposeAttr.get(transpose))]
+  out = hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side),
+                             ir.BoolAttr.get(lower),
+                             ir.BoolAttr.get(unit_diagonal),
+                             hlo.TransposeAttr.get(transpose))
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, out_aval)]
+  return [out]
 
 
 def _triangular_solve_cpu_lower(
@@ -1802,7 +1842,17 @@ def _geqrf_abstract_eval(operand):
   if operand.ndim < 2:
     raise ValueError("Argument to QR decomposition must have ndims >= 2")
   *batch_dims, m, n = operand.shape
-  taus = operand.update(shape=(*batch_dims, core.min_dim(m, n)))
+  if config.sharding_in_types.value:
+    spec = operand.sharding.spec
+    batch_s, ms, ns = spec[:-2], spec[-2], spec[-1]
+    if ms is not None or ns is not None:
+      raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}'
+                       ' specs. Try marking their specs as None.')
+    taus_s = operand.sharding.with_spec(P(*(*batch_s, None)))
+  else:
+    taus_s = None
+  taus = operand.update(shape=(*batch_dims, core.min_dim(m, n)),
+                        sharding=taus_s)
   return operand, taus
 
 def _geqrf_batching_rule(batched_args, batch_dims):
@@ -2024,13 +2074,23 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices):
       raise ValueError("Argument to QR decomposition must have ndims >= 2")
     *batch_dims, m, n = operand.shape
     k = m if full_matrices else core.min_dim(m, n)
-    q = operand.update(shape=(*batch_dims, m, k))
-    r = operand.update(shape=(*batch_dims, k, n))
-    p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32))
+    if config.sharding_in_types.value:
+      *batch_s, ms, ns = operand.sharding.spec
+      ks = None
+      if ms is not None or ns is not None:
+        raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}'
+                         ' specs. Try marking their specs as None.')
+      q_s = operand.sharding.with_spec(P(*(*batch_s, ms, ks)))
+      r_s = operand.sharding.with_spec(P(*(*batch_s, ks, ns)))
+      p_s = operand.sharding.with_spec(P(*(*batch_s, ns)))
+    else:
+      q_s, r_s, p_s = None, None, None
+    q = operand.update(shape=(*batch_dims, m, k), sharding=q_s)
+    r = operand.update(shape=(*batch_dims, k, n), sharding=r_s)
+    p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32),
+                       sharding=p_s)
   else:
-    q = operand
-    r = operand
-    p = operand
+    q, r, p = operand, operand, operand
   return (q, r, p) if pivoting else (q, r)
 
 def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
@@ -2136,13 +2196,32 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index,
       raise ValueError("full_matrices and subset_by_index cannot both be set")
     rank = min(rank, subset_by_index[1] - subset_by_index[0])
 
+    if config.sharding_in_types.value:
+      batch_s = operand.sharding.spec[:-2]
+      ms = operand.sharding.spec[-2]
+      ns = operand.sharding.spec[-1]
+      if ms is not None or ns is not None:
+        raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}'
+                         ' specs. Try marking their specs as None.')
+      rank_s = None
+      s_sharding = operand.sharding.with_spec(P(*batch_s + (rank_s,)))
+      u_sharding = operand.sharding.with_spec(
+          P(*batch_s + (ms, ms if full_matrices else rank_s)))
+      vt_sharding = operand.sharding.with_spec(
+          P(*batch_s + (ns if full_matrices else rank_s, ns)))
+    else:
+      s_sharding, u_sharding, vt_sharding = None, None, None
+
     s = operand.update(
         shape=batch_dims + (rank,),
         dtype=lax_internal._complex_basetype(operand.dtype),
+        sharding=s_sharding
     )
     if compute_uv:
-      u = operand.update(shape=batch_dims + (m, m if full_matrices else rank))
-      vt = operand.update(shape=batch_dims + (n if full_matrices else rank, n))
+      u = operand.update(shape=batch_dims + (m, m if full_matrices else rank),
+                         sharding=u_sharding)
+      vt = operand.update(shape=batch_dims + (n if full_matrices else rank, n),
+                          sharding=vt_sharding)
       return s, u, vt
     else:
       return s,
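For svd_p the output specs are built from the operand's batch spec, with the m, n and rank dimensions all left unsharded; full_matrices only decides which unsharded slot each output reuses. A minimal standalone sketch of that spec construction (hypothetical function, plain tuples; the rule itself wraps these back into the operand's sharding via with_spec):

# Hypothetical sketch of the spec shapes produced by _svd_abstract_eval.
def svd_output_specs(batch_spec: tuple, full_matrices: bool):
  ms = ns = rank_s = None  # m, n and rank must all be unsharded
  s_spec = batch_spec + (rank_s,)
  u_spec = batch_spec + (ms, ms if full_matrices else rank_s)
  vt_spec = batch_spec + (ns if full_matrices else rank_s, ns)
  return s_spec, u_spec, vt_spec

# Batch dimension sharded over 'x'; the matrix dimensions stay replicated.
s_spec, u_spec, vt_spec = svd_output_specs(('x',), full_matrices=False)
assert s_spec == ('x', None)
assert u_spec == ('x', None, None)
assert vt_spec == ('x', None, None)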

jax/_src/lax/slicing.py

Lines changed: 4 additions & 0 deletions
@@ -1886,6 +1886,10 @@ def _gather_sharding_rule(operand, indices, *, dimension_numbers,
   cur_mesh = mesh_lib.get_abstract_mesh()
   if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual:  # type: ignore
     return None
+  if (cur_mesh._are_all_axes_explicit and  # type: ignore
+      all(s is None for s in operand.sharding.spec) and
+      all(s is None for s in indices.sharding.spec)):
+    return None
   raise GatherShardingError(
       "Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for"
       " the gather indexing.")

tests/pjit_test.py

Lines changed: 11 additions & 0 deletions
@@ -6368,6 +6368,17 @@ def f(x):
     self.assertTupleEqual(out2.sharding._device_assignment,
                           tuple(mesh2.devices.flat))
 
+  @jtu.with_user_mesh((2, 1), ('x', 'y'))
+  def test_svd(self, mesh):
+    np_inp = np.zeros([128, 128])
+    arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, None)))
+
+    @jax.jit
+    def f(x):
+      return jnp.linalg.norm(x, 2)
+
+    f(arr)  # doesn't crash
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
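jnp.linalg.norm(x, 2) on a matrix is the spectral norm, computed from the largest singular value, so this test drives svd_p end to end. A rough public-API analogue of the new test, as a sketch: it assumes jax.make_mesh with a 1x1 mesh so it runs on a single device, whereas the committed test uses the internal jtu.with_user_mesh decorator with a (2, 1) mesh and depends on the experimental sharding-in-types configuration this commit targets.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P

# Build a small mesh; the committed test uses shape (2, 1) on two devices.
mesh = jax.make_mesh((1, 1), ('x', 'y'))
arr = jax.device_put(np.zeros((128, 128)), NamedSharding(mesh, P(None, None)))

@jax.jit
def spectral_norm(x):
  # Matrix 2-norm == largest singular value, so this lowers through svd_p.
  return jnp.linalg.norm(x, 2)

print(spectral_norm(arr))  # expect 0.0 for the all-zeros input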
