
Commit 0b3f0e1

Reverts ebb75db
PiperOrigin-RevId: 688477769
Parent: 84a303f
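
As an orientation aid (not part of the commit), the snippet below sketches a `lax.dot_general` call against the restored signature: the `out_type` keyword that ebb75db threaded through `dot_general`, `einsum`, and the sparse wrappers is removed, while `precision` and `preferred_element_type` remain. Shapes, dtypes, and values are illustrative only.

```python
# Minimal sketch: lax.dot_general after this revert, with no `out_type` keyword.
import jax.numpy as jnp
from jax import lax

x = jnp.ones((8, 16), dtype=jnp.bfloat16)
y = jnp.ones((16, 4), dtype=jnp.bfloat16)

# Contract x's axis 1 with y's axis 0; no batch dimensions.
out = lax.dot_general(
    x, y,
    dimension_numbers=(((1,), (0,)), ((), ())),
    precision=None,
    preferred_element_type=jnp.float32,  # still accepted after the revert
)
print(out.shape, out.dtype)  # (8, 4) float32
```

After the revert, passing `out_type=...` to any of these functions simply fails with an unexpected-keyword `TypeError`, since the parameter no longer exists.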

8 files changed: +34 −142 lines changed


jax/_src/lax/lax.py

Lines changed: 16 additions & 46 deletions
@@ -1040,8 +1040,7 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None,
 
 def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers,
                 precision: PrecisionLike = None,
-                preferred_element_type: DTypeLike | None = None,
-                out_type=None) -> Array:
+                preferred_element_type: DTypeLike | None = None) -> Array:
   """General dot product/contraction operator.
 
   Wraps XLA's `DotGeneral
@@ -1087,10 +1086,6 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
     by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
     non-contracting/non-batch dimensions.
   """
-  if out_type is not None and not isinstance(out_type, NamedSharding):
-    raise NotImplementedError(
-        '`out_type` argument of `dot_general` only supports NamedSharding '
-        'instances. Please file a bug if this is not enough for your use case.')
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
   cdims = (api_util._ensure_index_tuple(lhs_contract),
            api_util._ensure_index_tuple(rhs_contract))
@@ -1102,8 +1097,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
   return dot_general_p.bind(lhs, rhs,
                             dimension_numbers=(cdims, bdims),
                             precision=canonicalize_precision(precision),
-                            preferred_element_type=preferred_element_type,
-                            out_type=out_type)
+                            preferred_element_type=preferred_element_type)
 
 
 def ragged_dot(
@@ -3008,11 +3002,7 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
       not dtypes.issubdtype(new_dtype, np.complexfloating)):
     operand = hlo.real(operand)
     aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
-  out = mlir.convert_hlo(ctx, operand, aval_in, aval_out)
-  if config.sharding_in_types.value:
-    proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
-    return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
-  return [out]
+  return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)]
 
 mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
 
@@ -3174,8 +3164,7 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type):
 
 
 def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision,
-                            preferred_element_type: DTypeLike | None,
-                            out_type):
+                            preferred_element_type: DTypeLike | None):
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim))
              for d in (lhs_contracting, lhs_batch)):
@@ -3252,28 +3241,24 @@ def _check_specs_match(lhs_spec, rhs_spec, msg):
     raise TypeError(msg)
 
 def _dot_general_sharding_rule(lhs, rhs, *, dimension_numbers, precision,
-                               preferred_element_type: DTypeLike | None,
-                               out_type):
+                               preferred_element_type: DTypeLike | None):
   if lhs.sharding.mesh != rhs.sharding.mesh:
     raise ValueError(
         'Mesh of both lhs and rhs should match. Got lhs:'
         f' {lhs.sharding.mesh} and rhs: {rhs.sharding.mesh}')
 
-  if out_type is not None:
-    return out_type
-
   (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers
   lhs_batch_spec = tuple(lhs.sharding.spec[i] for i in lhs_batch)
   rhs_batch_spec = tuple(rhs.sharding.spec[i] for i in rhs_batch)
   msg = ("dot_general requires lhs batch dimensions and rhs batch dimensions "
-      f"to have the consistent sharding, got {lhs_batch_spec} and "
-      f"{rhs_batch_spec}.")
+         f"to have the consistent sharding, got {lhs_batch_spec} and "
+         f"{rhs_batch_spec}.")
   _check_specs_match(lhs_batch_spec, rhs_batch_spec, msg)
 
   lhs_contracting_spec = tuple(lhs.sharding.spec[i] for i in lhs_contracting)
   rhs_contracting_spec = tuple(rhs.sharding.spec[i] for i in rhs_contracting)
   msg = ("dot_general requires contracting dimensions to have consistent "
-      f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")
+         f"sharding, got {lhs_contracting_spec} and {rhs_contracting_spec}.")
   _check_specs_match(lhs_contracting_spec, rhs_contracting_spec, msg)
 
   return _dot_general_sharding_computation(
@@ -3295,8 +3280,7 @@ def tuple_delete(tup, idx):
 
 
 def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision,
-                            preferred_element_type: DTypeLike | None,
-                            out_type):
+                            preferred_element_type: DTypeLike | None):
   del dimension_numbers  # unused
   # We're mostly matching XLA's logic here, namely in shape_inference.cc and
   # primitive_util.h's HigherPrecisionType, e.g.
@@ -3343,7 +3327,7 @@ def _maybe_upcast(result_dtype, preferred_element_type, check_bit_width):
 
 def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
                                preferred_element_type: DTypeLike | None,
-                               out_type, swap_ans=False):
+                               swap_ans=False):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   x_ndim = x.aval.ndim
   x_kept = remaining(range(x_ndim), x_contract, x_batch)
@@ -3363,14 +3347,12 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
   return x_bar
 
 def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision,
-                               preferred_element_type: DTypeLike | None,
-                               out_type):
+                               preferred_element_type: DTypeLike | None):
   (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
   swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
   y_bar = _dot_general_transpose_lhs(
     g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision,
-    preferred_element_type=preferred_element_type, out_type=out_type,
-    swap_ans=True)
+    preferred_element_type=preferred_element_type, swap_ans=True)
   if y_bar.dtype != y.aval.dtype:
     y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type)
   return y_bar
@@ -3384,7 +3366,6 @@ def _dot_batch_rule(
     batch_dims,
     *,
     dimension_numbers,
-    out_type,
     precision,
     preferred_element_type: DTypeLike | None,
     **_,
@@ -3414,16 +3395,12 @@
     rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
   else:
     rhs_shape = np.shape(rhs)
-  if out_type is not None:
-    raise NotImplementedError("vmap with out_type is not supported. "
-                              "Please open an issue.")
   batched_out = invoke_prim(
       lhs,
       rhs,
       new_dimension_numbers,
       precision=precision,
       preferred_element_type=preferred_element_type,
-      out_type=out_type,
   )
   result_batch_dim = batching.shape_as_bdim(
       result_stack_dim,
@@ -3593,7 +3570,7 @@ def dot_algorithm_attr(precision: CanonicalPrecision, lhs_dtype: DTypeLike,
 
 def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                        precision, preferred_element_type: np.dtype | None,
-                       out_type, platform: str = "default"):
+                       platform: str = "default"):
   def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes):
     fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2,
                   dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz)
@@ -3681,8 +3658,6 @@ def maybe_convert_dtype(operand, operand_aval, target_dtype):
       **algorithm_kwarg,
   )
   if config.sharding_in_types.value:
-    if out_type is not None:
-      assert aval_out.sharding == out_type
     out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
     result = mlir.wrap_with_sharding_op(ctx, result, aval_out, out_sp)
   if accumulation_aval.dtype != aval_out.dtype:
@@ -3736,15 +3711,12 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S
   return (m, n)
 
 def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array,
-                           precision, preferred_element_type: DTypeLike | None,
-                           **_) -> np.dtype:
+                           precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype:
   if not dtypes.issubdtype(group_sizes.dtype, np.integer):
     raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.")
   # defer the output dtype to dot_general, which is part of the _ragged_dot_impl.
-  return _dot_general_dtype_rule(
-      lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
-      precision=precision, preferred_element_type=preferred_element_type,
-      out_type=None)
+  return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
+                                 precision=precision, preferred_element_type=preferred_element_type)
 
 
 def _ragged_dot_jvp_rule(
@@ -3883,7 +3855,6 @@ def _ragged_dot_batch_rule(
     *,
     precision,
     preferred_element_type: DTypeLike | None,
-    out_type,
     **_,
 ):
   invoke = functools.partial(_ragged_dot_invoke_prim, batched_args[2])
@@ -3897,7 +3868,6 @@
       dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
       precision=precision,
       preferred_element_type=preferred_element_type,
-      out_type=out_type,
   )
 
 

jax/_src/numpy/lax_numpy.py

Lines changed: 9 additions & 27 deletions
@@ -67,10 +67,10 @@
     DType, DTypeLike, DeprecatedArg, DimSize, DuckTypedArray, Shape, StaticScalar,
 )
 from jax._src.util import (
-    NumpyComplexWarning, canonicalize_axis as _canonicalize_axis,
-    ceil_of_ratio, partition_list, safe_zip, subvals,unzip2)
-from jax.sharding import (Sharding, SingleDeviceSharding, NamedSharding,
-                          PartitionSpec as P)
+    NumpyComplexWarning,
+    canonicalize_axis as _canonicalize_axis,
+    ceil_of_ratio, partition_list, safe_zip, subvals,unzip2)
+from jax.sharding import Sharding, SingleDeviceSharding
 from jax.tree_util import tree_flatten, tree_leaves, tree_map
 import numpy as np
 import opt_einsum
@@ -8955,7 +8955,6 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
-    out_type=None,
 ) -> Array: ...
 
 @overload
@@ -8968,7 +8967,6 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
-    out_type=None,
 ) -> Array: ...
 
 def einsum(
@@ -8979,7 +8977,6 @@ def einsum(
     precision: PrecisionLike = None,
     preferred_element_type: DTypeLike | None = None,
     _dot_general: Callable[..., Array] = lax.dot_general,
-    out_type=None,
 ) -> Array:
   """Einstein summation
 
@@ -9211,11 +9208,11 @@ def einsum(
 
   contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions)
 
-  einsum = jit(_einsum, static_argnums=(1, 2, 3, 4, 5), inline=True)
+  einsum = jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True)
   if spec is not None:
     einsum = jax.named_call(einsum, name=spec)
   return einsum(operands, contractions, precision,
-                preferred_element_type, _dot_general, out_type)
+                preferred_element_type, _dot_general)
 
 
 # Enable other modules to override einsum_contact_path.
@@ -9314,12 +9311,7 @@ def _einsum(
     precision,
     preferred_element_type,
     _dot_general=lax.dot_general,
-    out_type=None,
 ):
-  if out_type is not None and not isinstance(out_type, NamedSharding):
-    raise NotImplementedError(
-        "`out_type` argument of `einsum` only supports NamedSharding instances."
-        " Please file a bug if this is not enough for your use case.")
   dtypes.check_user_dtype_supported(preferred_element_type, "einsum")
   operands = list(map(asarray, operands))
   if preferred_element_type is None:
@@ -9442,21 +9434,12 @@ def filter_singleton_dims(operand, names, other_shape, other_names):
       if names == result_names:
         dimension_numbers = ((rhs_cont, lhs_cont), (rhs_batch, lhs_batch))
         operand = _dot_general(rhs, lhs, dimension_numbers, precision,
-                               preferred_element_type=preferred_element_type,
-                               out_type=out_type)
+                               preferred_element_type=preferred_element_type)
       else:
         names = batch_names_str + remaining_lhs_names + remaining_rhs_names
-        if (config.sharding_in_types.value and out_type is not None and
-            names != result_names):
-          spec = out_type.spec
-          inverse_spec = tuple(spec[result_names.index(name)] for name in names)
-          dot_general_out_type = NamedSharding(out_type.mesh, P(*inverse_spec))
-        else:
-          dot_general_out_type = out_type  # type: ignore
         dimension_numbers = ((lhs_cont, rhs_cont), (lhs_batch, rhs_batch))
         operand = _dot_general(lhs, rhs, dimension_numbers, precision,
-                               preferred_element_type=preferred_element_type,
-                               out_type=dot_general_out_type)
+                               preferred_element_type=preferred_element_type)
     else:
       raise NotImplementedError  # if this is actually reachable, open an issue!
 
@@ -9469,8 +9452,7 @@ def filter_singleton_dims(operand, names, other_shape, other_names):
       operand = lax.transpose(operand, perm)
     operands.append(operand)  # used in next iteration
 
-  return lax_internal._convert_element_type(operands[0], preferred_element_type,
-                                            output_weak_type)
+  return lax_internal._convert_element_type(operands[0], preferred_element_type, output_weak_type)
 
 
 @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True)
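
The hunks above drop `out_type` from the public `jnp.einsum` overloads and from the jitted `_einsum` helper, which is why `static_argnums` shrinks from `(1, 2, 3, 4, 5)` to `(1, 2, 3, 4)`. A minimal sketch of the remaining call surface, with illustrative shapes and dtypes that are not taken from the commit:

```python
# Sketch only: jnp.einsum after the revert keeps `precision` and
# `preferred_element_type`, but no longer accepts `out_type`.
import jax.numpy as jnp
from jax import lax

a = jnp.ones((2, 3, 4), dtype=jnp.bfloat16)
b = jnp.ones((2, 4, 5), dtype=jnp.bfloat16)

c = jnp.einsum(
    'bij,bjk->bik', a, b,
    precision=lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32,
)
print(c.shape, c.dtype)  # (2, 3, 5) float32
```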

jax/_src/pallas/triton/lowering.py

Lines changed: 1 addition & 2 deletions
@@ -2089,11 +2089,10 @@ def _dot_general_lowering(
     b,
     *,
     dimension_numbers,
-    out_type,
     precision,
     preferred_element_type,
 ):
-  del preferred_element_type, out_type  # Unused.
+  del preferred_element_type  # Unused.
   ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers
   assert batch_dims == ((), ())
 
jax/experimental/jax2tf/jax2tf.py

Lines changed: 1 addition & 1 deletion
@@ -2180,7 +2180,7 @@ def gen_conv(lhs, rhs, preferred_element_type: DType | None):
 tf_impl_with_avals[lax.conv_general_dilated_p] = _conv_general_dilated
 
 
-def _dot_general(lhs, rhs, *, dimension_numbers, out_type,
+def _dot_general(lhs, rhs, *, dimension_numbers,
                  precision: lax_internal.CanonicalPrecision,
                  preferred_element_type: DType | None,
                  _in_avals: Sequence[core.ShapedArray],

jax/experimental/sparse/bcoo.py

Lines changed: 4 additions & 8 deletions
@@ -606,11 +606,8 @@ def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation: Sequenc
 
 bcoo_dot_general_p = core.Primitive('bcoo_dot_general')
 
-def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *,
-                     dimension_numbers: DotDimensionNumbers,
-                     precision: None = None,
-                     preferred_element_type: None = None,
-                     out_type=None) -> BCOO | Array:
+def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers,
+                     precision: None = None, preferred_element_type: None = None) -> BCOO | Array:
   """A general contraction operation.
 
   Args:
@@ -628,7 +625,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *,
     the result will be dense, of type ndarray.
   """
   # TODO(jakevdp) make use of these?
-  del precision, out_type  # unused
+  del precision  # unused
   if isinstance(lhs, BCOO) and isinstance(rhs, BCOO):
     shape = _dot_general_validated_shape(lhs.shape, rhs.shape,
                                          dimension_numbers)
@@ -1054,8 +1051,7 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers)
   indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True)
   kwds = {'dimension_numbers': dimension_numbers,
           'precision': None,
-          'preferred_element_type': None,
-          'out_type': None}
+          'preferred_element_type': None}
   A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds)
   return A, B, indices
 
jax/experimental/sparse/bcsr.py

Lines changed: 2 additions & 3 deletions
@@ -462,8 +462,7 @@ def _bcsr_extract_batching_rule(batched_args, batch_dims):
 def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
                      dimension_numbers: DotDimensionNumbers,
                      precision: None = None,
-                     preferred_element_type: None = None,
-                     out_type=None) -> Array:
+                     preferred_element_type: None = None) -> Array:
   """A general contraction operation.
 
   Args:
@@ -480,7 +479,7 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *,
     are sparse, the result will be sparse, of type BCSR. If either input is
     dense, the result will be dense, of type ndarray.
   """
-  del precision, out_type  # unused
+  del precision  # unused
   if isinstance(rhs, (np.ndarray, jax.Array)):
     if isinstance(lhs, (np.ndarray, jax.Array)):
       return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers,
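
The sparse wrappers above follow the same signature change: `bcoo_dot_general` and `bcsr_dot_general` no longer accept or forward `out_type`. A small sketch of a post-revert call, using made-up toy data:

```python
# Sketch only: bcoo_dot_general after the revert takes dimension_numbers,
# precision, and preferred_element_type as keywords -- no `out_type`.
import jax.numpy as jnp
from jax.experimental import sparse

dense = jnp.array([[1., 0., 2.],
                   [0., 3., 0.]])
lhs = sparse.BCOO.fromdense(dense)  # sparse 2x3 operand
rhs = jnp.ones((3, 4))              # dense 3x4 operand

# Contract lhs axis 1 with rhs axis 0; a sparse-dense product yields a dense array.
out = sparse.bcoo_dot_general(
    lhs, rhs,
    dimension_numbers=(((1,), (0,)), ((), ())),
)
print(out.shape)  # (2, 4)
```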

jax/experimental/sparse/util.py

Lines changed: 1 addition & 1 deletion
@@ -111,4 +111,4 @@ def _dot_general_validated_shape(
   rhs = core.ShapedArray(rhs_shape, np.float32)
   return _dot_general_shape_rule(
       lhs, rhs, dimension_numbers=dimension_numbers,
-      precision=None, preferred_element_type=None, out_type=None)
+      precision=None, preferred_element_type=None)
