Skip to content

Commit 9fbc1c1

Browse files
Merge pull request #221 from ROCm/ci-upstream-sync-106_1
CI: 02/04/25 upstream sync
2 parents 0962b96 + c20bb5b commit 9fbc1c1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+1449
-832
lines changed

.github/workflows/ci-build.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ jobs:
9696
echo "JAX_THREEFRY_PARTITIONABLE=$JAX_THREEFRY_PARTITIONABLE"
9797
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
9898
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
99-
pytest -n auto --tb=short --maxfail=20 tests examples
99+
pytest -n 4 --tb=short --maxfail=20 tests examples
100100
101101
102102
documentation:

.github/workflows/tsan.yaml

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ jobs:
4444
repository: python/cpython
4545
path: cpython
4646
ref: "3.13"
47+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
48+
with:
49+
repository: numpy/numpy
50+
path: numpy
51+
submodules: true
4752

4853
- name: Restore cached CPython with TSAN
4954
id: cache-cpython-tsan-restore
@@ -67,7 +72,7 @@ jobs:
6772
# Create archive to be used with bazel as hermetic python:
6873
cd ${GITHUB_WORKSPACE} && tar -czpf python-tsan.tgz cpython-tsan
6974
70-
- name: Save CPython with TSAN
75+
- name: Save TSAN CPython
7176
id: cache-cpython-tsan-save
7277
if: steps.cache-cpython-tsan-restore.outputs.cache-hit != 'true'
7378
uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
@@ -76,6 +81,73 @@ jobs:
7681
./python-tsan.tgz
7782
key: ${{ runner.os }}-cpython-tsan-${{ hashFiles('cpython/configure.ac') }}
7883

84+
- name: Get year & week number
85+
id: get-date
86+
run: echo "date=$(/bin/date "+%Y-%U")" >> $GITHUB_OUTPUT
87+
shell: bash -l {0}
88+
89+
- name: Restore cached TSAN Numpy
90+
id: cache-numpy-tsan-restore
91+
uses: actions/cache/restore@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
92+
with:
93+
path: |
94+
./wheelhouse
95+
key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }}
96+
97+
- name: Build TSAN Numpy wheel
98+
if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true'
99+
run: |
100+
cd numpy
101+
102+
# If we restored cpython from cache, we need to get python interpreter from python-tsan.tgz
103+
if [ ! -d ${GITHUB_WORKSPACE}/cpython-tsan/bin/ ]; then
104+
echo "Extract cpython from python-tsan.tgz"
105+
pushd .
106+
ls ${GITHUB_WORKSPACE}/python-tsan.tgz
107+
cd ${GITHUB_WORKSPACE} && tar -xzf python-tsan.tgz
108+
ls ${GITHUB_WORKSPACE}/cpython-tsan/bin/
109+
popd
110+
fi
111+
112+
export PATH=${GITHUB_WORKSPACE}/cpython-tsan/bin/:$PATH
113+
114+
python3 -m pip install -r requirements/build_requirements.txt
115+
# Make sure to install a compatible Cython version (master branch is best for now)
116+
python3 -m pip install -U git+https://github.com/cython/cython
117+
118+
CC=clang-18 CXX=clang++-18 python3 -m pip wheel --wheel-dir dist -v . --no-build-isolation -Csetup-args=-Db_sanitize=thread -Csetup-args=-Dbuildtype=debugoptimized
119+
120+
# Create simple index and copy the wheel
121+
mkdir -p ${GITHUB_WORKSPACE}/wheelhouse/numpy
122+
123+
numpy_whl_name=($(cd dist && ls numpy*.whl))
124+
if [ -z "${numpy_whl_name}" ]; then exit 1; fi
125+
126+
echo "Built TSAN Numpy wheel: ${numpy_whl_name}"
127+
128+
cp dist/${numpy_whl_name} ${GITHUB_WORKSPACE}/wheelhouse/numpy
129+
130+
cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/index.html
131+
<!DOCTYPE html><html><body>
132+
<a href="numpy">numpy></a></br>
133+
</body></html>
134+
EOF
135+
136+
cat << EOF > ${GITHUB_WORKSPACE}/wheelhouse/numpy/index.html
137+
<!DOCTYPE html><html><body>
138+
<a href="${numpy_whl_name}">${numpy_whl_name}</a></br>
139+
</body></html>
140+
EOF
141+
142+
- name: Save TSAN Numpy wheel
143+
id: cache-numpy-tsan-save
144+
if: steps.cache-numpy-tsan-restore.outputs.cache-hit != 'true'
145+
uses: actions/cache/save@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
146+
with:
147+
path: |
148+
./wheelhouse
149+
key: ${{ runner.os }}-numpy-tsan-${{ hashFiles('numpy/pyproject.toml') }}-${{ steps.get-date.outputs.date }}
150+
79151
- name: Build Jax and run tests
80152
timeout-minutes: 120
81153
env:

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2121
decorator to support customizing the behavior of opaque functions under
2222
JAX-level dead code elimination (DCE). See {jax-issue}`#25956` for more
2323
details.
24+
* Added {func}`jax.random.multinomial`.
2425

2526
* Changes
2627
* `JAX_CPU_COLLECTIVES_IMPLEMENTATION` and `JAX_NUM_CPU_DEVICES` now work as

build/test-requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ flatbuffers
77
hypothesis
88
mpmath>=1.3
99
pillow>=10.4.0
10-
portpicker
10+
# TODO(kanglan): Remove once psutil from portpicker supports python 3.13t
11+
portpicker; python_version<"3.13"
1112
pytest-xdist
1213
wheel
1314
rich

docs/autodidax.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@
146146
"around calls to `bind`. These wrappers let us control how arguments are passed\n",
147147
"to `bind`, and in particular we follow a handy internal convention: when we\n",
148148
"call `bind`, we pass values representing array data as positional arguments,\n",
149-
"and we pass metadata like the `axis` argument to `sum_p` via keyword. This\n",
149+
"and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This\n",
150150
"calling convention simplifies some core logic (since e.g. instances of the\n",
151151
"`Tracer` class to be defined below can only occur in positional arguments to\n",
152152
"`bind`). The wrappers can also provide docstrings!\n",

docs/autodidax.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ The functions that user code calls, like `add` and `sin`, are just wrappers
133133
around calls to `bind`. These wrappers let us control how arguments are passed
134134
to `bind`, and in particular we follow a handy internal convention: when we
135135
call `bind`, we pass values representing array data as positional arguments,
136-
and we pass metadata like the `axis` argument to `sum_p` via keyword. This
136+
and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This
137137
calling convention simplifies some core logic (since e.g. instances of the
138138
`Tracer` class to be defined below can only occur in positional arguments to
139139
`bind`). The wrappers can also provide docstrings!

docs/autodidax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def bind1(prim, *args, **params):
123123
# around calls to `bind`. These wrappers let us control how arguments are passed
124124
# to `bind`, and in particular we follow a handy internal convention: when we
125125
# call `bind`, we pass values representing array data as positional arguments,
126-
# and we pass metadata like the `axis` argument to `sum_p` via keyword. This
126+
# and we pass metadata like the `axis` argument to `reduce_sum_p` via keyword. This
127127
# calling convention simplifies some core logic (since e.g. instances of the
128128
# `Tracer` class to be defined below can only occur in positional arguments to
129129
# `bind`). The wrappers can also provide docstrings!

docs/jax.random.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ Random Samplers
5353
logistic
5454
lognormal
5555
maxwell
56+
multinomial
5657
multivariate_normal
5758
normal
5859
orthogonal

docs/persistent_compilation_cache.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,36 @@ so it is important for the persistent cache to be in a shared file system (eg: N
168168
If the persistent cache is local to rank 0, then all processes except rank 0 will once again compile
169169
in subsequent runs as a result of a compilation cache miss.
170170

171+
### Pre-compiling multi-node programs on single node
172+
173+
JAX can populate the compilation cache with compiled programs for multiple nodes
174+
on a single node. Preparing the cache on a single node helps to decrease the costly
175+
compilation time on a cluster. To compile and run multi-node programs on a single
176+
node, users can create fake remote devices using
177+
the `jax_mock_gpu_topology` configuration option.
178+
179+
For instance, the snippet below instructs JAX to mock a cluster with four
180+
nodes, each node running eight processes with each process attached to one GPU.
181+
182+
```python
183+
jax.config.update("jax_mock_gpu_topology", "4x8x1")
184+
```
185+
186+
After populating the cache with this config, users can run the program
187+
without recompilation on four nodes, eight processes per node,
188+
one GPU per process.
189+
190+
Important notes:
191+
192+
* The process running the mocked program must have the same amount of GPUs
193+
and the same GPU model as the nodes that would use the cache. For instance,
194+
a mocked topology `8x4x2` must run in a process with two GPUs.
195+
196+
* When running programs with mocked topology, the results of communications
197+
with other nodes are undefined, so the outputs of JAX programs running
198+
in mocked environments will likely be incorrect.
199+
200+
171201
## Logging cache activity
172202

173203
It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging.

jax/_src/abstract_arrays.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,24 +45,28 @@
4545

4646

4747
def masked_array_error(*args, **kwargs):
48-
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
49-
"Use arr.filled() to convert the value to a standard numpy array.")
48+
raise ValueError(
49+
"numpy masked arrays are not supported as direct inputs to JAX functions."
50+
" Use arr.filled() to convert the value to a standard numpy array.")
5051

5152
core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
5253

5354

5455
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
5556
dtype = x.dtype
5657
dtypes.check_valid_dtype(dtype)
57-
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
58+
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype),
59+
sharding=core.get_cur_mesh_sharding(core.P(*[None] * x.ndim)))
5860

5961
core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
6062

6163

6264
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
6365
dtype = np.dtype(x)
6466
dtypes.check_valid_dtype(dtype)
65-
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
67+
shape = np.shape(x)
68+
return ShapedArray(shape, dtypes.canonicalize_dtype(dtype),
69+
sharding=core.get_cur_mesh_sharding(core.P(*[None] * len(shape))))
6670

6771
for t in numpy_scalar_types:
6872
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
@@ -74,7 +78,8 @@ def _make_abstract_python_scalar(typ, val):
7478
# Note: all python scalar types are weak except bool, because bool only
7579
# comes in a single width.
7680
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
77-
weak_type=typ is not bool)
81+
weak_type=typ is not bool,
82+
sharding=core.get_cur_mesh_sharding())
7883

7984
for t in dtypes.python_scalar_dtypes:
8085
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)

0 commit comments

Comments
 (0)