Skip to content

Commit c46b4fc

Browse files
authored
Merge pull request #294 from ROCm/ci-upstream-sync-151_1
CI: 03/18/25 upstream sync
2 parents d864b4f + c7b407c commit c46b4fc

File tree

92 files changed

+1843
-373
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

92 files changed

+1843
-373
lines changed

.github/workflows/pytest_cpu.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@ jobs:
118118
run: |
119119
$JAXCI_PYTHON -m pip install uv~=0.5.30
120120
$JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt
121+
122+
# CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632
123+
if [[ $OS == "linux" && $ARCH == "aarch64" ]]; then
124+
$JAXCI_PYTHON -m uv pip install numpy~=2.1.0
125+
fi
121126
# Halt for testing
122127
- name: Wait For Connection
123128
uses: google-ml-infra/actions/ci_connection@main

.github/workflows/pytest_cuda.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ jobs:
5454
runs-on: ${{ inputs.runner }}
5555
# TODO: Update to the generic ML ecosystem test containers when they are ready.
5656
container: ${{ (contains(inputs.cuda, '12.3') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.3-cudnn9.1:latest') ||
57-
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') }}
57+
(contains(inputs.cuda, '12.1') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.1-cudnn9.1:latest') ||
58+
(contains(inputs.cuda, '12.8') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-cuda12.8-cudnn9.8:latest') }}
5859
name: "Pytest CUDA (${{ inputs.runner }}, CUDA ${{ inputs.cuda }}, Python ${{ inputs.python }}, x64=${{ inputs.enable-x64 }})"
5960

6061
env:

.github/workflows/wheel_tests_continuous.yml

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,18 +110,30 @@ jobs:
110110
fail-fast: false # don't cancel all jobs on failure
111111
matrix:
112112
# Python values need to match the matrix stategy in the artifact build jobs above
113-
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
113+
# See exlusions for what is fully tested
114+
runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu","linux-x86-a4-224-b200-1gpu"]
114115
python: ["3.10",]
115-
cuda: ["12.3", "12.1"]
116+
cuda: ["12.1","12.3","12.8"]
116117
enable-x64: [1, 0]
117118
exclude:
118-
# Run only a single configuration on H100 to save resources
119+
# L4 does not run on cuda 12.8 but tests other configs
120+
- runner: "linux-x86-g2-48-l4-4gpu"
121+
cuda: "12.8"
122+
# H100 runs only a single config, CUDA 12.3 Enable x64 1
123+
- runner: "linux-x86-a3-8g-h100-8gpu"
124+
cuda: "12.8"
119125
- runner: "linux-x86-a3-8g-h100-8gpu"
120-
python: "3.10"
121126
cuda: "12.1"
122127
- runner: "linux-x86-a3-8g-h100-8gpu"
123-
python: "3.10"
124-
enable-x64: 0
128+
enable-x64: "0"
129+
# B200 runs only a single config, CUDA 12.8 Enable x64 1
130+
- runner: "linux-x86-a4-224-b200-1gpu"
131+
enable-x64: "0"
132+
- runner: "linux-x86-a4-224-b200-1gpu"
133+
cuda: "12.1"
134+
- runner: "linux-x86-a4-224-b200-1gpu"
135+
cuda: "12.3"
136+
125137
name: "Pytest CUDA (JAX artifacts version = ${{ format('{0}', 'head') }})"
126138
with:
127139
runner: ${{ matrix.runner }}

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
2222
{func}`jax.lax.dynamic_update_slice` and related functions. The default is
2323
true, matching the current behavior. If set to false, JAX does not need to
2424
emit code clamping negative indices, which improves code size.
25+
* Added a `replace` option to {func}`jax.random.categorical` to enable sampling
26+
without replacement.
2527

2628
## jax 0.5.2 (Mar 4, 2025)
2729

build/test-requirements.txt

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,4 @@ setuptools
1818
matplotlib~=3.8.4; python_version=="3.10"
1919
matplotlib; python_version>="3.11"
2020
opt-einsum
21-
auditwheel
22-
23-
# CPU Pytests crash with NumPy 2.2+ on Linux Aarch64; b/399168632
24-
numpy~=2.1.0; platform_system == "Linux" and platform_machine == "aarch64"
21+
auditwheel

docs/notebooks/explicit-sharding.ipynb

Lines changed: 104 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,9 @@
4949
},
5050
{
5151
"cell_type": "code",
52-
"execution_count": 1,
52+
"execution_count": 7,
5353
"metadata": {
54-
"colab": {
55-
"base_uri": "https://localhost:8080/"
56-
},
57-
"id": "hVi6mApuVw3r",
58-
"outputId": "a64bcbcb-27f8-4c57-8931-8091c9bb8ebf"
54+
"id": "hVi6mApuVw3r"
5955
},
6056
"outputs": [],
6157
"source": [
@@ -84,13 +80,13 @@
8480
},
8581
{
8682
"cell_type": "code",
87-
"execution_count": 2,
83+
"execution_count": 8,
8884
"metadata": {
8985
"colab": {
9086
"base_uri": "https://localhost:8080/"
9187
},
9288
"id": "mzDIDvj7Vw0k",
93-
"outputId": "417b8453-9c86-4e76-a886-4fa9fdb16434"
89+
"outputId": "09ef049b-461f-47db-bf58-dc10b42fe40a"
9490
},
9591
"outputs": [
9692
{
@@ -119,13 +115,13 @@
119115
},
120116
{
121117
"cell_type": "code",
122-
"execution_count": 3,
118+
"execution_count": 9,
123119
"metadata": {
124120
"colab": {
125121
"base_uri": "https://localhost:8080/"
126122
},
127123
"id": "IyPx_-IBVwxr",
128-
"outputId": "7d6e4fcb-f6a8-4ed8-ae41-61cf478fa499"
124+
"outputId": "0cd3122f-e579-45d7-868d-e42bb0eacddb"
129125
},
130126
"outputs": [
131127
{
@@ -141,7 +137,7 @@
141137
"Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32)"
142138
]
143139
},
144-
"execution_count": 3,
140+
"execution_count": 9,
145141
"metadata": {},
146142
"output_type": "execute_result"
147143
}
@@ -172,13 +168,13 @@
172168
},
173169
{
174170
"cell_type": "code",
175-
"execution_count": 4,
171+
"execution_count": 10,
176172
"metadata": {
177173
"colab": {
178174
"base_uri": "https://localhost:8080/"
179175
},
180176
"id": "NO2ulM_QW7a8",
181-
"outputId": "ea313610-146c-41f4-95b4-c5a5b2b407cb"
177+
"outputId": "d888371b-080e-4bff-be5d-ea56beda3aac"
182178
},
183179
"outputs": [
184180
{
@@ -208,13 +204,13 @@
208204
},
209205
{
210206
"cell_type": "code",
211-
"execution_count": 5,
207+
"execution_count": 11,
212208
"metadata": {
213209
"colab": {
214210
"base_uri": "https://localhost:8080/"
215211
},
216212
"id": "1-TzmA0AXCAf",
217-
"outputId": "15b33b6d-3915-4725-da6d-4f31fb78fe71"
213+
"outputId": "1c7cc3ac-4b0e-42b7-facc-c706af10d7d2"
218214
},
219215
"outputs": [
220216
{
@@ -256,13 +252,13 @@
256252
},
257253
{
258254
"cell_type": "code",
259-
"execution_count": 6,
255+
"execution_count": 12,
260256
"metadata": {
261257
"colab": {
262258
"base_uri": "https://localhost:8080/"
263259
},
264260
"id": "Gy7ABds3XND3",
265-
"outputId": "4ced73ed-5872-45f3-a4a6-2138f942e01b"
261+
"outputId": "0d72dad2-381a-4e96-f771-40d705da1376"
266262
},
267263
"outputs": [
268264
{
@@ -297,13 +293,13 @@
297293
},
298294
{
299295
"cell_type": "code",
300-
"execution_count": 7,
296+
"execution_count": 13,
301297
"metadata": {
302298
"colab": {
303299
"base_uri": "https://localhost:8080/"
304300
},
305301
"id": "grCcotr-XQjY",
306-
"outputId": "9a9f381d-5111-4824-9bc0-cb2472cb8e6a"
302+
"outputId": "c2db656c-809f-49a6-c948-629d6420360c"
307303
},
308304
"outputs": [
309305
{
@@ -324,7 +320,7 @@
324320
" [ 3, 4, 5, 6, 7, 8, 9, 10]], dtype=int32)"
325321
]
326322
},
327-
"execution_count": 7,
323+
"execution_count": 13,
328324
"metadata": {},
329325
"output_type": "execute_result"
330326
}
@@ -460,13 +456,13 @@
460456
},
461457
{
462458
"cell_type": "code",
463-
"execution_count": 13,
459+
"execution_count": 14,
464460
"metadata": {
465461
"colab": {
466462
"base_uri": "https://localhost:8080/"
467463
},
468464
"id": "fpFEaMBcXsJG",
469-
"outputId": "d28a69eb-260f-4fc5-8f19-2cc64cc70660"
465+
"outputId": "5b84b1d1-d7b2-4e9a-ba98-3dd34a5465ef"
470466
},
471467
"outputs": [
472468
{
@@ -479,13 +475,6 @@
479475
"We're in auto-sharding mode here. This is the current mesh: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Auto))\n",
480476
"Result type: ShapedArray(int32[4@X,4])\n"
481477
]
482-
},
483-
{
484-
"name": "stdout",
485-
"output_type": "stream",
486-
"text": [
487-
"Result type: ShapedArray(int32[4@X,4])\n"
488-
]
489478
}
490479
],
491480
"source": [
@@ -550,13 +539,13 @@
550539
},
551540
{
552541
"cell_type": "code",
553-
"execution_count": 10,
542+
"execution_count": 15,
554543
"metadata": {
555544
"colab": {
556545
"base_uri": "https://localhost:8080/"
557546
},
558547
"id": "geptWrdYX0OM",
559-
"outputId": "c0e62eb1-9f79-4d1c-e708-526165ca680f"
548+
"outputId": "b8c3813f-60bb-4ccf-9da7-73462c57963f"
560549
},
561550
"outputs": [
562551
{
@@ -588,7 +577,88 @@
588577
{
589578
"cell_type": "markdown",
590579
"metadata": {
591-
"id": "AQQjzUeGX4P6"
580+
"id": "LZWjgiMZ7uSS"
581+
},
582+
"source": [
583+
"You can use the `auto_axes` API to be `Auto` over some mesh axes while being `Explicit` over other. For example:"
584+
]
585+
},
586+
{
587+
"cell_type": "code",
588+
"execution_count": 27,
589+
"metadata": {
590+
"colab": {
591+
"base_uri": "https://localhost:8080/"
592+
},
593+
"id": "IVzPSkp77uCF",
594+
"outputId": "db80a604-98ac-4343-8677-23729adf7ffc"
595+
},
596+
"outputs": [
597+
{
598+
"name": "stdout",
599+
"output_type": "stream",
600+
"text": [
601+
"mesh inside f: AbstractMesh('X': 2, 'Y': 4, axis_types=(Explicit, Explicit))\n",
602+
"x.sharding: ShapedArray(float32[4@X,4@Y])\n",
603+
"\n",
604+
"mesh inside g: AbstractMesh('X': 2, 'Y': 4, axis_types=(Auto, Explicit))\n",
605+
"y.sharding inside g: jax.typeof(y) = ShapedArray(float32[4,4@Y])\n",
606+
"\n",
607+
"z.sharding: ShapedArray(float32[4@X,4@Y])\n",
608+
"\n"
609+
]
610+
},
611+
{
612+
"data": {
613+
"text/plain": [
614+
"Array([[ 1. , 2.682942 , 2.818595 , 1.28224 ],\n",
615+
" [-0.513605 , -0.9178486 , 0.44116902, 2.3139732 ],\n",
616+
" [ 2.9787164 , 1.824237 , -0.08804226, -0.99998045],\n",
617+
" [-0.07314587, 1.840334 , 2.9812148 , 2.3005757 ]], dtype=float32)"
618+
]
619+
},
620+
"execution_count": 27,
621+
"metadata": {},
622+
"output_type": "execute_result"
623+
}
624+
],
625+
"source": [
626+
"import functools\n",
627+
"\n",
628+
"@functools.partial(auto_axes, axes='X')\n",
629+
"def g(y):\n",
630+
" print(f'mesh inside g: {get_abstract_mesh()}')\n",
631+
" print(f'y.sharding inside g: {jax.typeof(y) = }', end='\\n\\n')\n",
632+
" return y * 2\n",
633+
"\n",
634+
"@jax.jit\n",
635+
"def f(arr1):\n",
636+
" print(f'mesh inside f: {get_abstract_mesh()}')\n",
637+
" x = jnp.sin(arr1)\n",
638+
" print(f'x.sharding: {jax.typeof(x)}', end='\\n\\n')\n",
639+
"\n",
640+
" z = g(x, out_shardings=P(\"X\", \"Y\"))\n",
641+
"\n",
642+
" print(f'z.sharding: {jax.typeof(z)}', end=\"\\n\\n\")\n",
643+
" return z + 1\n",
644+
"\n",
645+
"some_x = reshard(np.arange(16).reshape(4, 4), P(\"X\", \"Y\"))\n",
646+
"f(some_x)"
647+
]
648+
},
649+
{
650+
"cell_type": "markdown",
651+
"metadata": {
652+
"id": "_3sfJjRq8w9f"
653+
},
654+
"source": [
655+
"As you can see, inside `g`, the type of `arr1` is `ShapedArray(float32[4,4@Y])` which indicates it's Explicit over `Y` mesh axis while auto over `X`."
656+
]
657+
},
658+
{
659+
"cell_type": "markdown",
660+
"metadata": {
661+
"id": "sJcWbfAh7UcO"
592662
},
593663
"source": [
594664
"## Concrete array shardings can mention `Auto` mesh axis\n",
@@ -606,7 +676,7 @@
606676
},
607677
{
608678
"cell_type": "code",
609-
"execution_count": 25,
679+
"execution_count": null,
610680
"metadata": {
611681
"colab": {
612682
"base_uri": "https://localhost:8080/"
@@ -708,5 +778,5 @@
708778
}
709779
},
710780
"nbformat": 4,
711-
"nbformat_minor": 4
781+
"nbformat_minor": 0
712782
}

0 commit comments

Comments
 (0)