Skip to content

Commit b505df9

Browse files
Merge pull request #299 from ROCm/ci-upstream-sync-152_1
CI: 03/19/25 upstream sync
2 parents 1f2fe33 + e9ce8fb commit b505df9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1048
-649
lines changed

.github/workflows/requirements_lock_3_13_ft.patch

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
diff --git a/build/requirements_lock_3_13_ft.txt b/build/requirements_lock_3_13_ft.txt
2-
index dfefaf042..2700e140e 100644
2+
index e7a2968e9..d37e11ee3 100644
33
--- a/build/requirements_lock_3_13_ft.txt
44
+++ b/build/requirements_lock_3_13_ft.txt
5-
@@ -4,6 +4,12 @@
5+
@@ -4,6 +4,11 @@
66
#
77
# pip-compile --allow-unsafe --generate-hashes --output-file=build/requirements_lock_3_13_ft.txt build/requirements.in
88
#
99
+
1010
+--pre
1111
+--extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
1212
+numpy
13-
+
1413
+
1514
absl-py==2.1.0 \
1615
--hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \
1716
--hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff
18-
@@ -328,68 +334,6 @@ mpmath==1.3.0 \
17+
@@ -328,68 +333,6 @@ mpmath==1.3.0 \
1918
--hash=sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f \
2019
--hash=sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c
2120
# via -r build/test-requirements.txt
@@ -81,6 +80,6 @@ index dfefaf042..2700e140e 100644
8180
- # matplotlib
8281
- # ml-dtypes
8382
- # scipy
84-
opt-einsum==3.4.0 \
85-
--hash=sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd \
86-
--hash=sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac
83+
nvidia-cublas-cu12==12.8.3.14 ; sys_platform == "linux" \
84+
--hash=sha256:3f0e05e7293598cf61933258b73e66a160c27d59c4422670bf0b79348c04be44 \
85+
--hash=sha256:93a4e0e386cc7f6e56c822531396de8170ed17068a1e18f987574895044cd8c3 \

.github/workflows/tsan.yaml

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,18 @@ jobs:
173173
--bazel_options=--copt=-g \
174174
--clang_path=/usr/bin/clang-18
175175
176-
# Update the patch to use TSAN instrumented numpy
176+
# Patch build/requirements_lock_3_13_ft.txt to use TSAN instrumented NumPy
177177
sed -i "s|+--extra-index-url.*|+--extra-index-url file://${GITHUB_WORKSPACE}/wheelhouse/|" .github/workflows/requirements_lock_3_13_ft.patch
178178
cat .github/workflows/requirements_lock_3_13_ft.patch
179+
git apply .github/workflows/requirements_lock_3_13_ft.patch || exit 1
179180
180-
# Apply a patch to numpy in requirements lock 3.13 ft to use the nightly version
181-
git apply .github/workflows/requirements_lock_3_13_ft.patch
181+
# Display the content for debugging in logs
182+
cat build/requirements_lock_3_13_ft.txt | head -15
183+
# Check the patch
184+
cat build/requirements_lock_3_13_ft.txt | head -15 | grep -E "(--pre|.*${GITHUB_WORKSPACE}/wheelhouse/|numpy)"
185+
if [ "$?" == "1" ]; then echo "Could not find the patch in the requirements_lock_3_13_ft.txt"; exit 1; fi
186+
cat build/requirements_lock_3_13_ft.txt | grep -E "(numpy==)"
187+
if [ "$?" == "0" ]; then "Found original numpy dependency in the requirements_lock_3_13_ft.txt"; exit 1; fi
182188
183189
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
184190
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
@@ -188,6 +194,13 @@ jobs:
188194
bazel_exec=($(ls bazel-*))
189195
ln -s ${bazel_exec} bazel
190196
197+
# Check python version
198+
./bazel run --@rules_python//python/config_settings:py_freethreaded="yes" @python//:python3 -- -VV
199+
200+
# Check numpy version
201+
./bazel cquery @pypi_numpy//:* | grep whl
202+
203+
# Build JAX and run tests
191204
./bazel test \
192205
--test_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \
193206
--test_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \

jax/_src/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from jax._src import profiler
3434
from jax._src import util
3535
from jax._src import xla_bridge
36-
from jax._src.mesh import use_concrete_mesh
3736
from jax._src.interpreters import mlir
3837
from jax._src.interpreters import pxla
3938
from jax._src.interpreters import xla
@@ -43,7 +42,8 @@
4342
from jax._src.sharding import Sharding
4443
from jax._src.sharding_impls import (
4544
PmapSharding, SingleDeviceSharding,
46-
device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable
45+
device_replica_id_map, hashed_index, num_addressable_indices,
46+
local_to_global_shape, use_concrete_mesh) # pyformat: disable
4747
from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike
4848
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
4949
import numpy as np

jax/_src/blocked_sampler.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ def __call__(self, key: ArrayLike, *args, shape: Shape,
2929

3030

3131
def _compute_tile_index(block_index: Sequence[int],
32-
total_size_in_blocks: Shape,
3332
block_size_in_tiles: Shape,
33+
total_size_in_tiles: Shape,
3434
tile_index_in_block: Sequence[int]) -> int:
3535
ndims = len(block_index)
3636
dim_size = 1
3737
total_idx = 0
3838
for i in range(ndims-1, -1, -1):
3939
dim_idx = tile_index_in_block[i] + block_index[i] * block_size_in_tiles[i]
4040
total_idx += dim_idx * dim_size
41-
dim_size *= total_size_in_blocks[i] * block_size_in_tiles[i]
41+
dim_size *= total_size_in_tiles[i]
4242
return total_idx
4343

4444

@@ -103,15 +103,17 @@ def blocked_fold_in(
103103
_shape // _element for _shape, _element in zip(block_size, tile_size)
104104
)
105105

106-
total_size_in_blocks = tuple(
107-
_shape // _element for _shape, _element in zip(total_size, block_size)
106+
# Round up to make sure every tile is numbered.
107+
total_size_in_tiles = tuple(
108+
(_shape + _element - 1) // _element
109+
for _shape, _element in zip(total_size, tile_size)
108110
)
109111

110112
def _keygen_loop(axis, prefix):
111113
if axis == len(block_size_in_tiles):
112114
subtile_key = jax.random.fold_in(
113115
global_key, _compute_tile_index(
114-
block_index, total_size_in_blocks, block_size_in_tiles, prefix))
116+
block_index, block_size_in_tiles, total_size_in_tiles, prefix))
115117
return subtile_key
116118
else:
117119
keys = []

jax/_src/custom_partitioning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
179179
for sharding, s in zip(result_shardings, result_shapes)
180180
]
181181
closed_jaxpr = jax.make_jaxpr(lower_fn, axis_env=list(mesh.shape.items()))(
182-
*tiled_args
182+
*info.in_tree.unflatten(tiled_args)
183183
)
184184
if ([(o.shape, o.dtype) for o in closed_jaxpr.out_avals] !=
185185
[(t.shape, t.dtype) for t in tiled_results]):

jax/_src/interpreters/partial_eval.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
JaxprEqn, Primitive, ShapedArray, DShapedArray,
4242
mapped_aval, unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
4343
InputType, OutputType, get_referent, JaxprEqnContext)
44-
from jax._src.state.types import AbstractRef
44+
from jax._src.state.types import AbstractRef, ReadEffect
4545
from jax._src.tree_util import (PyTreeDef, treedef_tuple,
4646
tree_flatten, tree_structure)
4747
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
@@ -1423,7 +1423,8 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
14231423

14241424

14251425
def has_effects(eqn: JaxprEqn) -> bool:
1426-
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)}
1426+
effs = {e for e in eqn.effects if not isinstance(e, core.NamedAxisEffect)
1427+
and not isinstance(e, ReadEffect)}
14271428
return bool(effs)
14281429

14291430

jax/_src/lax/lax.py

Lines changed: 143 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -615,8 +615,23 @@ def tanh(x: ArrayLike) -> Array:
615615
"""
616616
return tanh_p.bind(x)
617617

618+
@export
618619
def logistic(x: ArrayLike) -> Array:
619-
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`."""
620+
r"""Elementwise logistic (sigmoid) function: :math:`\frac{1}{1 + e^{-x}}`.
621+
622+
There is no HLO logistic/sigmoid primitive, so this lowers to a sequence
623+
of HLO arithmetic operations.
624+
625+
Args:
626+
x: input array. Must have floating point or complex dtype.
627+
628+
Returns:
629+
Array of the same shape and dtype as ``x`` containing the element-wise
630+
logistic/sigmoid function.
631+
632+
See also:
633+
- :func:`jax.nn.sigmoid`: an alternative API for this functionality.
634+
"""
620635
return logistic_p.bind(x)
621636

622637
@export
@@ -1018,12 +1033,45 @@ def bitwise_xor(x: ArrayLike, y: ArrayLike) -> Array:
10181033
"""
10191034
return xor_p.bind(x, y)
10201035

1036+
@export
10211037
def population_count(x: ArrayLike) -> Array:
1022-
r"""Elementwise popcount, count the number of set bits in each element."""
1038+
r"""Elementwise popcount, count the number of set bits in each element.
1039+
1040+
This function lowers directly to the `stablehlo.popcnt`_ operation.
1041+
1042+
Args:
1043+
x: Input array. Must have integer dtype.
1044+
1045+
Returns:
1046+
An array of the same shape and dtype as ``x``, containing the number of
1047+
set bits in the input.
1048+
1049+
See also:
1050+
- :func:`jax.lax.clz`: Elementwise count leading zeros.
1051+
- :func:`jax.numpy.bitwise_count`: More flexible NumPy-style API for bit counts.
1052+
1053+
.. _stablehlo.popcnt: https://openxla.org/stablehlo/spec#popcnt
1054+
"""
10231055
return population_count_p.bind(x)
10241056

1057+
@export
10251058
def clz(x: ArrayLike) -> Array:
1026-
r"""Elementwise count-leading-zeros."""
1059+
r"""Elementwise count-leading-zeros.
1060+
1061+
This function lowers directly to the `stablehlo.count_leading_zeros`_ operation.
1062+
1063+
Args:
1064+
x: Input array. Must have integer dtype.
1065+
1066+
Returns:
1067+
An array of the same shape and dtype as ``x``, containing the number of
1068+
set bits in the input.
1069+
1070+
See also:
1071+
- :func:`jax.lax.population_count`: Count the number of set bits in each element.
1072+
1073+
.. _stablehlo.count_leading_zeros: https://openxla.org/stablehlo/spec#count_leading_zeros
1074+
"""
10271075
return clz_p.bind(x)
10281076

10291077
@export
@@ -1124,31 +1172,81 @@ def div(x: ArrayLike, y: ArrayLike) -> Array:
11241172
"""
11251173
return div_p.bind(x, y)
11261174

1175+
@export
11271176
def rem(x: ArrayLike, y: ArrayLike) -> Array:
11281177
r"""Elementwise remainder: :math:`x \bmod y`.
11291178
1130-
The sign of the result is taken from the dividend,
1131-
and the absolute value of the result is always
1132-
less than the divisor's absolute value.
1179+
This function lowers directly to the `stablehlo.remainder`_ operation.
1180+
The sign of the result is taken from the dividend, and the absolute value
1181+
of the result is always less than the divisor's absolute value.
11331182
1134-
Integer division overflow
1135-
(remainder by zero or remainder of INT_SMIN with -1)
1183+
Integer division overflow (remainder by zero or remainder of INT_SMIN with -1)
11361184
produces an implementation defined value.
1185+
1186+
Args:
1187+
x, y: Input arrays. Must have matching int or float dtypes. If neither
1188+
is a scalar, ``x`` and ``y`` must have the same number of dimensions
1189+
and be broadcast compatible.
1190+
1191+
Returns:
1192+
An array of the same dtype as ``x`` and ``y`` containing the remainder.
1193+
1194+
See also:
1195+
- :func:`jax.numpy.remainder`: NumPy-style remainder with different
1196+
sign semantics.
1197+
1198+
.. _stablehlo.remainder: https://openxla.org/stablehlo/spec#remainder
11371199
"""
11381200
return rem_p.bind(x, y)
11391201

1202+
@export
11401203
def max(x: ArrayLike, y: ArrayLike) -> Array:
1141-
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`
1204+
r"""Elementwise maximum: :math:`\mathrm{max}(x, y)`.
1205+
1206+
This function lowers directly to the `stablehlo.maximum`_ operation for
1207+
non-complex inputs. For complex numbers, this uses a lexicographic
1208+
comparison on the `(real, imaginary)` pairs.
1209+
1210+
Args:
1211+
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
1212+
``x`` and ``y`` must have the same rank and be broadcast compatible.
11421213
1143-
For complex numbers, uses a lexicographic comparison on the
1144-
`(real, imaginary)` pairs."""
1214+
Returns:
1215+
An array of the same dtype as ``x`` and ``y`` containing the elementwise
1216+
maximum.
1217+
1218+
See also:
1219+
- :func:`jax.numpy.maximum`: more flexibly NumPy-style maximum.
1220+
- :func:`jax.lax.reduce_max`: maximum along an axis of an array.
1221+
- :func:`jax.lax.min`: elementwise minimum.
1222+
1223+
.. _stablehlo.maximum: https://openxla.org/stablehlo/spec#maximum
1224+
"""
11451225
return max_p.bind(x, y)
11461226

1227+
@export
11471228
def min(x: ArrayLike, y: ArrayLike) -> Array:
1148-
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
1229+
r"""Elementwise minimum: :math:`\mathrm{min}(x, y)`
1230+
1231+
This function lowers directly to the `stablehlo.minimum`_ operation for
1232+
non-complex inputs. For complex numbers, this uses a lexicographic
1233+
comparison on the `(real, imaginary)` pairs.
1234+
1235+
Args:
1236+
x, y: Input arrays. Must have matching dtypes. If neither is a scalar,
1237+
``x`` and ``y`` must have the same rank and be broadcast compatible.
11491238
1150-
For complex numbers, uses a lexicographic comparison on the
1151-
`(real, imaginary)` pairs."""
1239+
Returns:
1240+
An array of the same dtype as ``x`` and ``y`` containing the elementwise
1241+
minimum.
1242+
1243+
See also:
1244+
- :func:`jax.numpy.minimum`: more flexibly NumPy-style minimum.
1245+
- :func:`jax.lax.reduce_min`: minimum along an axis of an array.
1246+
- :func:`jax.lax.max`: elementwise maximum.
1247+
1248+
.. _stablehlo.minimum: https://openxla.org/stablehlo/spec#minimum
1249+
"""
11521250
return min_p.bind(x, y)
11531251

11541252
@export
@@ -1408,21 +1506,38 @@ def lt(x: ArrayLike, y: ArrayLike) -> Array:
14081506
"""
14091507
return lt_p.bind(x, y)
14101508

1509+
@export
14111510
def convert_element_type(operand: ArrayLike,
14121511
new_dtype: DTypeLike | dtypes.ExtendedDType) -> Array:
14131512
"""Elementwise cast.
14141513
1415-
Wraps XLA's `ConvertElementType
1416-
<https://www.tensorflow.org/xla/operation_semantics#convertelementtype>`_
1417-
operator, which performs an elementwise conversion from one type to another.
1418-
Similar to a C++ `static_cast`.
1514+
This function lowers directly to the `stablehlo.convert`_ operation, which
1515+
performs an elementwise conversion from one type to another, similar to a
1516+
C++ ``static_cast``.
14191517
14201518
Args:
14211519
operand: an array or scalar value to be cast.
1422-
new_dtype: a NumPy dtype representing the target type.
1520+
new_dtype: a dtype-like object (e.g. a :class:`numpy.dtype`, a scalar type,
1521+
or a valid dtype name) representing the target dtype.
14231522
14241523
Returns:
1425-
An array with the same shape as `operand`, cast elementwise to `new_dtype`.
1524+
An array with the same shape as ``operand``, cast elementwise to ``new_dtype``.
1525+
1526+
.. note::
1527+
1528+
If ``new_dtype`` is a 64-bit type and `x64 mode`_ is not enabled,
1529+
the appropriate 32-bit type will be used in its place.
1530+
1531+
If the input is a JAX array and the input dtype and output dtype match, then
1532+
the input array will be returned unmodified.
1533+
1534+
See also:
1535+
- :func:`jax.numpy.astype`: NumPy-style dtype casting API.
1536+
- :meth:`jax.Array.astype`: dtype casting as an array method.
1537+
- :func:`jax.lax.bitcast_convert_type`: cast bits directly to a new dtype.
1538+
1539+
.. _stablehlo.convert: https://openxla.org/stablehlo/spec#convert
1540+
.. _x64 mode: https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
14261541
"""
14271542
return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type]
14281543

@@ -1500,12 +1615,11 @@ def _convert_element_type(
15001615
operand, new_dtype=new_dtype, weak_type=bool(weak_type),
15011616
sharding=sharding)
15021617

1618+
@export
15031619
def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
15041620
"""Elementwise bitcast.
15051621
1506-
Wraps XLA's `BitcastConvertType
1507-
<https://www.tensorflow.org/xla/operation_semantics#bitcastconverttype>`_
1508-
operator, which performs a bit cast from one type to another.
1622+
This function lowers directly to the `stablehlo.bitcast_convert`_ operation.
15091623
15101624
The output shape depends on the size of the input and output dtypes with
15111625
the following logic::
@@ -1525,6 +1639,12 @@ def bitcast_convert_type(operand: ArrayLike, new_dtype: DTypeLike) -> Array:
15251639
Returns:
15261640
An array of shape `output_shape` (see above) and type `new_dtype`,
15271641
constructed from the same bits as operand.
1642+
1643+
See also:
1644+
- :func:`jax.lax.convert_element_type`: value-preserving dtype conversion.
1645+
- :func:`jax.Array.view`: NumPy-style API for bitcast type conversion.
1646+
1647+
.. _stablehlo.bitcast_convert: https://openxla.org/stablehlo/spec#bitcast_convert
15281648
"""
15291649
new_dtype = dtypes.canonicalize_dtype(new_dtype)
15301650
return bitcast_convert_type_p.bind(operand, new_dtype=new_dtype)

0 commit comments

Comments
 (0)