 from jax._src.lax import control_flow
 from jax._src.lax import eigh as lax_eigh
 from jax._src.lax import lax as lax_internal
+from jax._src.partition_spec import PartitionSpec as P
 from jax._src.lax import svd as lax_svd
 from jax._src.lax.lax import (
     standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
@@ -960,9 +961,20 @@ def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
 
     batch_dims = operand.shape[:-2]
     n = operand.shape[-1]
+    if config.sharding_in_types.value:
+      batch_s = operand.sharding.spec[:-2]
+      ns = operand.sharding.spec[-1]
+      if ns is not None:
+        raise ValueError(f'n should be unsharded. Got n: {ns} '
+                         'specs. Try marking their specs as None.')
+      w_s = operand.sharding.with_spec(P(*batch_s + (ns,)))
+      v_s = operand.sharding.with_spec(P(*batch_s + (ns, ns)))
+    else:
+      w_s, v_s = None, None
     w = operand.update(shape=batch_dims + (n,),
-                       dtype=lax_internal._complex_basetype(operand.dtype))
-    v = operand.update(shape=batch_dims + (n, n))
+                       dtype=lax_internal._complex_basetype(operand.dtype),
+                       sharding=w_s)
+    v = operand.update(shape=batch_dims + (n, n), sharding=v_s)
   else:
     w, v = operand, operand
   return w, v
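
For reference, the spec propagation above can be restated on plain jax.sharding.PartitionSpec values. This is purely illustrative (the helper name below is made up and is not part of the patch): since n must be unsharded, only the batch portion of the operand's spec carries over to w and v.

    from jax.sharding import PartitionSpec as P

    def eigh_jacobi_out_specs(operand_spec):
      # Batch dims keep their spec; the matrix dimension n must be None.
      *batch_s, _, ns = operand_spec
      if ns is not None:
        raise ValueError('n should be unsharded')
      return P(*batch_s, ns), P(*batch_s, ns, ns)  # specs for w and v

    w_spec, v_spec = eigh_jacobi_out_specs(P('data', None, None))
    # w_spec == P('data', None); v_spec == P('data', None, None)
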
@@ -1029,16 +1041,23 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index):
 
     batch_dims = operand.shape[:-2]
     n = operand.shape[-1]
-    d = (
-        n
-        if subset_by_index is None
-        else subset_by_index[1] - subset_by_index[0]
-    )
-    v = operand.update(shape=batch_dims + (n, d))
+    d = (n if subset_by_index is None else
+         subset_by_index[1] - subset_by_index[0])
+    if config.sharding_in_types.value:
+      batch_s = operand.sharding.spec[:-2]
+      ns, ds = operand.sharding.spec[-1], None
+      if ns is not None:
+        raise ValueError(f'n should be unsharded. Got n: {ns} specs. Try '
+                         'marking their specs as None.')
+      v_s = operand.sharding.with_spec(P(*batch_s + (ns, ds)))
+      w_s = operand.sharding.with_spec(P(*batch_s + (ds,)))
+    else:
+      v_s, w_s = None, None
+    v = operand.update(shape=batch_dims + (n, d), sharding=v_s)
     w = operand.update(
         shape=batch_dims + (d,),
         dtype=lax_internal._complex_basetype(operand.dtype),
-    )
+        sharding=w_s)
   else:
     v, w = operand, operand
   return v, w
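
The subset_by_index handling only changes the trailing output dimension d, which, like n, is never sharded. A small worked example with made-up shapes and specs:

    # Operand of shape (8, 128, 128) with spec P('data', None, None) and
    # subset_by_index=(0, 5):
    n = 128
    subset_by_index = (0, 5)
    d = (n if subset_by_index is None else
         subset_by_index[1] - subset_by_index[0])
    assert d == 5
    # v: shape (8, 128, 5), spec P('data', None, None)
    # w: shape (8, 5),      spec P('data', None)
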
@@ -1249,6 +1268,24 @@ def _triangular_solve_shape_rule(a, b, *, left_side=False, **unused_kwargs):
     raise TypeError(msg.format(a.shape, b.shape))
   return b.shape
 
+def _triangular_solve_sharding_rule(a, b, *, left_side=False, **unused_kwargs):
+  a_spec, b_spec = a.sharding.spec, b.sharding.spec
+  if a_spec[-1] != a_spec[-2]:
+    raise TypeError(
+        "triangular_solve requires the last two dimensions of a to be equal "
+        f"in sharding, got a_spec of {a_spec}.")
+  if a_spec[:-2] != b_spec[:-2]:
+    raise TypeError(
+        "triangular_solve requires both arguments to have the same number "
+        f"of dimensions and equal batch shardings, got {a_spec} and {b_spec}.")
+  common_dim = -2 if left_side else -1
+  if a_spec[-1] != b_spec[common_dim]:
+    raise TypeError(
+        "Incompatible shardings for arguments to triangular_solve:"
+        f" {a_spec} and {b_spec}.")
+  return b.sharding
+
+
 def _triangular_solve_jvp_rule_a(
     g_a, ans, a, b, *, left_side, lower, transpose_a, conjugate_a,
     unit_diagonal):
@@ -1328,7 +1365,7 @@ def _triangular_solve_batching_rule(batched_args, batch_dims, *, left_side,
 
 triangular_solve_p = standard_primitive(
     _triangular_solve_shape_rule, _triangular_solve_dtype_rule,
-    'triangular_solve')
+    'triangular_solve', sharding_rule=_triangular_solve_sharding_rule)
 ad.defjvp2(triangular_solve_p,
            _triangular_solve_jvp_rule_a,
            lambda g_b, _, a, b, **kws: triangular_solve(a, g_b, **kws))
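
As a sketch, the checks performed by the new sharding rule can be restated on bare PartitionSpecs; the helper below is illustrative only (the real rule receives abstract values, not specs):

    from jax.sharding import PartitionSpec as P

    def check_triangular_solve_specs(a_spec, b_spec, left_side=False):
      if a_spec[-1] != a_spec[-2]:
        raise TypeError('last two dims of a must be sharded identically')
      if tuple(a_spec)[:-2] != tuple(b_spec)[:-2]:
        raise TypeError('batch shardings of a and b must match')
      common_dim = -2 if left_side else -1  # dim of b that a contracts against
      if a_spec[-1] != b_spec[common_dim]:
        raise TypeError('contracted dims of a and b must be sharded identically')
      return b_spec  # the solution inherits b's sharding

    # OK: batch axis 'x' matches, matrix dims unsharded on both arguments.
    check_triangular_solve_specs(P('x', None, None), P('x', None, None),
                                 left_side=True)
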
@@ -1346,10 +1383,13 @@ def _triangular_solve_lowering(
     transpose = "NO_TRANSPOSE"
   else:
     transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
-  return [hlo.triangular_solve(
-      a, b, ir.BoolAttr.get(left_side),
-      ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
-      hlo.TransposeAttr.get(transpose))]
+  out = hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side),
+                             ir.BoolAttr.get(lower),
+                             ir.BoolAttr.get(unit_diagonal),
+                             hlo.TransposeAttr.get(transpose))
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, out, out_aval)]
+  return [out]
 
 
 def _triangular_solve_cpu_lower(
@@ -1802,7 +1842,17 @@ def _geqrf_abstract_eval(operand):
   if operand.ndim < 2:
     raise ValueError("Argument to QR decomposition must have ndims >= 2")
   *batch_dims, m, n = operand.shape
-  taus = operand.update(shape=(*batch_dims, core.min_dim(m, n)))
+  if config.sharding_in_types.value:
+    spec = operand.sharding.spec
+    batch_s, ms, ns = spec[:-2], spec[-2], spec[-1]
+    if ms is not None or ns is not None:
+      raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns} '
+                       'specs. Try marking their specs as None.')
+    taus_s = operand.sharding.with_spec(P(*(*batch_s, None)))
+  else:
+    taus_s = None
+  taus = operand.update(shape=(*batch_dims, core.min_dim(m, n)),
+                        sharding=taus_s)
   return operand, taus
 
 def _geqrf_batching_rule(batched_args, batch_dims):
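
For a hypothetical operand spec of P('data', None, None), taus therefore carries only the batch spec plus an unsharded trailing dimension:

    from jax.sharding import PartitionSpec as P

    batch_s = ('data',)               # hypothetical batch portion of the spec
    taus_spec = P(*(*batch_s, None))  # taus: (*batch, min(m, n)), tail unsharded
    # taus_spec == P('data', None)
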
@@ -2024,13 +2074,23 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices):
       raise ValueError("Argument to QR decomposition must have ndims >= 2")
     *batch_dims, m, n = operand.shape
     k = m if full_matrices else core.min_dim(m, n)
-    q = operand.update(shape=(*batch_dims, m, k))
-    r = operand.update(shape=(*batch_dims, k, n))
-    p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32))
+    if config.sharding_in_types.value:
+      *batch_s, ms, ns = operand.sharding.spec
+      ks = None
+      if ms is not None or ns is not None:
+        raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns} '
+                         'specs. Try marking their specs as None.')
+      q_s = operand.sharding.with_spec(P(*(*batch_s, ms, ks)))
+      r_s = operand.sharding.with_spec(P(*(*batch_s, ks, ns)))
+      p_s = operand.sharding.with_spec(P(*(*batch_s, ns)))
+    else:
+      q_s, r_s, p_s = None, None, None
+    q = operand.update(shape=(*batch_dims, m, k), sharding=q_s)
+    r = operand.update(shape=(*batch_dims, k, n), sharding=r_s)
+    p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32),
+                       sharding=p_s)
   else:
-    q = operand
-    r = operand
-    p = operand
+    q, r, p = operand, operand, operand
   return (q, r, p) if pivoting else (q, r)
 
 def qr_jvp_rule(primals, tangents, *, pivoting, full_matrices):
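
A sketch of the resulting output specs on plain PartitionSpecs (illustrative helper, not the internal API); because m, n and the derived k are all required to be unsharded, q, r and p only inherit the batch spec:

    from jax.sharding import PartitionSpec as P

    def qr_out_specs(operand_spec):
      *batch_s, ms, ns = operand_spec  # ms and ns must both be None here
      ks = None                        # k = m or min(m, n), never sharded
      return (P(*batch_s, ms, ks),     # q: (*batch, m, k)
              P(*batch_s, ks, ns),     # r: (*batch, k, n)
              P(*batch_s, ns))         # p: (*batch, n), only with pivoting

    qr_out_specs(P('data', None, None))
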
@@ -2136,13 +2196,32 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index,
       raise ValueError("full_matrices and subset_by_index cannot both be set")
     rank = min(rank, subset_by_index[1] - subset_by_index[0])
 
+  if config.sharding_in_types.value:
+    batch_s = operand.sharding.spec[:-2]
+    ms = operand.sharding.spec[-2]
+    ns = operand.sharding.spec[-1]
+    if ms is not None or ns is not None:
+      raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns} '
+                       'specs. Try marking their specs as None.')
+    rank_s = None
+    s_sharding = operand.sharding.with_spec(P(*batch_s + (rank_s,)))
+    u_sharding = operand.sharding.with_spec(
+        P(*batch_s + (ms, ms if full_matrices else rank_s)))
+    vt_sharding = operand.sharding.with_spec(
+        P(*batch_s + (ns if full_matrices else rank_s, ns)))
+  else:
+    s_sharding, u_sharding, vt_sharding = None, None, None
+
   s = operand.update(
       shape=batch_dims + (rank,),
       dtype=lax_internal._complex_basetype(operand.dtype),
+      sharding=s_sharding
   )
   if compute_uv:
-    u = operand.update(shape=batch_dims + (m, m if full_matrices else rank))
-    vt = operand.update(shape=batch_dims + (n if full_matrices else rank, n))
+    u = operand.update(shape=batch_dims + (m, m if full_matrices else rank),
+                       sharding=u_sharding)
+    vt = operand.update(shape=batch_dims + (n if full_matrices else rank, n),
+                        sharding=vt_sharding)
     return s, u, vt
   else:
     return s,
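
All of these abstract-eval rules expect an operand whose matrix dimensions are replicated, with any sharding confined to the batch dimensions. A minimal sketch of building such an operand with the public sharding API (mesh size, axis name and shapes are made up; enabling the experimental sharding-in-types mode itself is not shown):

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # One-device mesh just for illustration; the batch axis is named 'data'.
    mesh = Mesh(np.array(jax.devices()[:1]), ('data',))
    # Batch of matrices: only the leading (batch) dimension is sharded; the
    # m and n dimensions are left as None, which is what the checks require.
    x = jax.device_put(np.zeros((8, 16, 16), np.float32),
                       NamedSharding(mesh, P('data', None, None)))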