Skip to content

Commit f04f164

Browse files
authored
Merge pull request #176 from ROCm/main
CI: 12/10/24 upstream sync
2 parents 048dc29 + 263d4d1 commit f04f164

File tree

160 files changed

+5457
-2089
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

160 files changed

+5457
-2089
lines changed

.bazelrc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1
104104
build:clang --copt=-Wno-gnu-offsetof-extensions
105105
# Disable clang extention that rejects unknown arguments.
106106
build:clang --copt=-Qunused-arguments
107+
# Error on struct/class mismatches, since this causes link failures on Windows.
108+
build:clang --copt=-Werror=mismatched-tags
107109

108110
# Configs for CUDA
109111
build:cuda --repo_env TF_NEED_CUDA=1

.github/workflows/ci-build.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
with:
3636
python-version: 3.11
3737
- run: python -m pip install pre-commit
38-
- uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
38+
- uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
3939
with:
4040
path: ~/.cache/pre-commit
4141
key: pre-commit-${{ env.pythonLocation }}-${{ hashFiles('.pre-commit-config.yaml', 'setup.py') }}
@@ -71,7 +71,7 @@ jobs:
7171
python -m pip install --upgrade pip wheel
7272
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
7373
- name: pip cache
74-
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
74+
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
7575
with:
7676
path: ${{ steps.pip-cache.outputs.dir }}
7777
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -118,7 +118,7 @@ jobs:
118118
python -m pip install --upgrade pip wheel
119119
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
120120
- name: pip cache
121-
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
121+
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
122122
with:
123123
path: ${{ steps.pip-cache.outputs.dir }}
124124
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -155,7 +155,7 @@ jobs:
155155
python -m pip install --upgrade pip wheel
156156
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
157157
- name: pip cache
158-
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
158+
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
159159
with:
160160
path: ${{ steps.pip-cache.outputs.dir }}
161161
key: ${{ runner.os }}-pip-docs-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -190,7 +190,7 @@ jobs:
190190
python -m pip install --upgrade pip wheel
191191
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
192192
- name: pip cache
193-
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
193+
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
194194
with:
195195
path: ${{ steps.pip-cache.outputs.dir }}
196196
key: ${{ runner.os }}-py${{ matrix.python-version }}-pip-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt') }}
@@ -229,7 +229,7 @@ jobs:
229229
python -m pip install --upgrade pip wheel
230230
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
231231
- name: pip cache
232-
uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
232+
uses: actions/cache@1bd1e32a3bdc45362d1e726936510720a7c30a57 # v4.2.0
233233
with:
234234
path: ${{ steps.pip-cache.outputs.dir }}
235235
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}

.github/workflows/cloud-tpu-ci-nightly.yml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ jobs:
5454
with:
5555
repository: openxla/xla
5656
path: xla
57+
# We need to mark the GitHub workspace as safe as otherwise git commands will fail.
58+
- name: Mark GitHub workspace as safe
59+
run: |
60+
git config --global --add safe.directory "$GITHUB_WORKSPACE"
5761
- name: Install JAX test requirements
5862
run: |
5963
$PYTHON -m pip install -U -r build/test-requirements.txt
@@ -63,9 +67,11 @@ jobs:
6367
$PYTHON -m pip uninstall -y jax jaxlib libtpu
6468
if [ "${{ matrix.jaxlib-version }}" == "head" ]; then
6569
# Build and install jaxlib at head
66-
$PYTHON build/build.py --bazel_options=--config=rbe_linux_x86_64 \
67-
--bazel_options="--override_repository=xla=$(pwd)/xla" \
68-
--bazel_options=--color=yes
70+
$PYTHON build/build.py build --wheels=jaxlib \
71+
--bazel_options=--config=rbe_linux_x86_64 \
72+
--local_xla_path="$(pwd)/xla" \
73+
--verbose
74+
6975
$PYTHON -m pip install dist/*.whl
7076
7177
# Install "jax" at head

.github/workflows/jax-array-api.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ jobs:
2828
with:
2929
repository: data-apis/array-api-tests
3030
# TODO(jakevdp) update this to a stable release/tag when available.
31-
ref: 'ad81cf6c3721d9dbeb168bdab49c962b6b38c0d5' # Latest commit as of 2024-11-20
31+
ref: '1572b129c6682211abfe139e112592226c361a6c' # Latest commit as of 2024-12-04
3232
submodules: 'true'
3333
path: 'array-api-tests'
3434
- name: Set up Python ${{ matrix.python-version }}
@@ -38,11 +38,11 @@ jobs:
3838
- name: Install dependencies
3939
run: |
4040
python -m pip install .[ci]
41-
python -m pip install -r array-api-tests/requirements.txt
41+
python -m pip install pytest-xdist -r array-api-tests/requirements.txt
4242
- name: Run the test suite
4343
env:
4444
ARRAY_API_TESTS_MODULE: jax.numpy
4545
JAX_ENABLE_X64: 'true'
4646
run: |
4747
cd ${GITHUB_WORKSPACE}/array-api-tests
48-
pytest array_api_tests --max-examples=5 --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt
48+
pytest -n auto array_api_tests --derandomize --disable-deadline --skips-file ${GITHUB_WORKSPACE}/tests/array_api_skips.txt

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ repos:
3636
- id: mypy
3737
files: (jax/|tests/typing_test\.py)
3838
exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead
39-
additional_dependencies: [types-requests==2.31.0, jaxlib]
39+
additional_dependencies: [types-requests==2.31.0, jaxlib, numpy~=2.1.0]
4040
args: [--config=pyproject.toml]
4141

4242
- repo: https://github.com/mwouts/jupytext

CHANGELOG.md

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,20 @@ Remember to align the itemized text with the first line of an item within a list
1010
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
1111
-->
1212

13-
## jax 0.4.36
13+
## jax 0.4.38
14+
15+
## jax 0.4.37 (Dec 9, 2024)
16+
17+
This is a patch release of jax 0.4.36. Only "jax" was released at this version.
18+
19+
* Bug fixes
20+
* Fixed a bug where `jit` would error if an argument was named `f` (#25329).
21+
* Fix a bug that will throw `index out of range` error in
22+
{func}`jax.lax.while_loop` if the user register pytree node class with
23+
different aux data for the flatten and flatten_with_path.
24+
* Pinned a new libtpu release (0.0.6) that fixes a compiler bug on TPU v6e.
25+
26+
## jax 0.4.36 (Dec 5, 2024)
1427

1528
* Breaking Changes
1629
* This release lands "stackless", an internal change to JAX's tracing
@@ -53,6 +66,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
5366
use `uses_global_constants`.
5467
* the `lowering_platforms` kwarg for {func}`jax.export.export`: use
5568
`platforms` instead.
69+
* The kwargs `symbolic_scope` and `symbolic_constraints` from
70+
{func}`jax.export.symbolic_args_specs` have been removed. They were
71+
deprecated in June 2024. Use `scope` and `constraints` instead.
5672
* Hashing of tracers, which has been deprecated since version 0.4.30, now
5773
results in a `TypeError`.
5874
* Refactor: JAX build CLI (build/build.py) now uses a subcommand structure and
@@ -67,6 +83,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
6783
return NaN for negative integer inputs, to match the behavior of SciPy from
6884
https://github.com/scipy/scipy/pull/21827.
6985
* `jax.clear_backends` was removed after being deprecated in v0.4.26.
86+
* We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
87+
call that we guarantee export stability. This is because this custom call
88+
relies on Triton IR, which is not guaranteed to be stable. If you need
89+
to export code that uses this custom call, you can use the `disabled_checks`
90+
parameter. See more details in the [documentation](https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees-for-custom-calls).
7091

7192
* New Features
7293
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
@@ -79,6 +100,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
79100
* {func}`jax.lax.linalg.eig` and the related `jax.numpy` functions
80101
({func}`jax.numpy.linalg.eig` and {func}`jax.numpy.linalg.eigvals`) are now
81102
supported on GPU. See {jax-issue}`#24663` for more details.
103+
* Added two new configuration flags, `jax_exec_time_optimization_effort` and `jax_memory_fitting_effort`, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.
82104

83105
* Bug fixes
84106
* Fixed a bug where the GPU implementations of LU and QR decomposition would

build/build.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,15 @@ def add_global_arguments(parser: argparse.ArgumentParser):
123123
help="Produce verbose output for debugging.",
124124
)
125125

126+
parser.add_argument(
127+
"--detailed_timestamped_log",
128+
action="store_true",
129+
help="""
130+
Enable detailed logging of the Bazel command with timestamps. The logs
131+
will be stored and can be accessed as artifacts.
132+
""",
133+
)
134+
126135

127136
def add_artifact_subcommand_arguments(parser: argparse.ArgumentParser):
128137
"""Adds all the arguments that applies to the artifact subcommands."""
@@ -399,7 +408,7 @@ async def main():
399408
else:
400409
requirements_command.append("//build:requirements.update")
401410

402-
result = await executor.run(requirements_command.get_command_as_string(), args.dry_run)
411+
result = await executor.run(requirements_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
403412
if result.return_code != 0:
404413
raise RuntimeError(f"Command failed with return code {result.return_code}")
405414
else:
@@ -597,7 +606,7 @@ async def main():
597606

598607
wheel_build_command.append(f"--jaxlib_git_hash={git_hash}")
599608

600-
result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run)
609+
result = await executor.run(wheel_build_command.get_command_as_string(), args.dry_run, args.detailed_timestamped_log)
601610
# Exit with error if any wheel build fails.
602611
if result.return_code != 0:
603612
raise RuntimeError(f"Command failed with return code {result.return_code}")

build/tools/command.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(self, environment: Dict[str, str] = None):
7575
"""
7676
self.environment = environment or dict(os.environ)
7777

78-
async def run(self, cmd: str, dry_run: bool = False) -> CommandResult:
78+
async def run(self, cmd: str, dry_run: bool = False, detailed_timestamped_log: bool = False) -> CommandResult:
7979
"""
8080
Executes a subprocess command.
8181
@@ -96,14 +96,15 @@ async def run(self, cmd: str, dry_run: bool = False) -> CommandResult:
9696

9797
process = await asyncio.create_subprocess_shell(
9898
cmd,
99-
stdout=asyncio.subprocess.PIPE,
100-
stderr=asyncio.subprocess.PIPE,
99+
stdout=asyncio.subprocess.PIPE if detailed_timestamped_log else None,
100+
stderr=asyncio.subprocess.PIPE if detailed_timestamped_log else None,
101101
env=self.environment,
102102
)
103103

104-
await asyncio.gather(
105-
_process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result)
106-
)
104+
if detailed_timestamped_log:
105+
await asyncio.gather(
106+
_process_log_stream(process.stdout, result), _process_log_stream(process.stderr, result)
107+
)
107108

108109
result.return_code = await process.wait()
109110
result.end_time = datetime.datetime.now()

ci/build_artifacts.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ if [[ "${allowed_artifacts[@]}" =~ "${artifact}" ]]; then
6969
fi
7070

7171
# Build the artifact.
72-
python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose
72+
python build/build.py build --wheels="$artifact" --bazel_options=--config="$bazelrc_config" --python_version=$JAXCI_HERMETIC_PYTHON_VERSION --verbose --detailed_timestamped_log
7373

7474
# If building `jaxlib` or `jax-cuda-plugin` or `jax-cuda-pjrt` for Linux, we
7575
# run `auditwheel show` to verify manylinux compliance.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""
16+
Converts MSYS Linux-like paths stored in env variables to Windows paths.
17+
18+
This is necessary on Windows, because some applications do not understand/handle
19+
Linux-like paths MSYS uses, for example, Bazel.
20+
"""
21+
import argparse
22+
import os
23+
import subprocess
24+
25+
def msys_to_windows_path(msys_path):
26+
"""Converts an MSYS path to a Windows path using cygpath.
27+
28+
Args:
29+
msys_path: The MSYS path to convert.
30+
31+
Returns:
32+
The corresponding Windows path.
33+
"""
34+
try:
35+
# Use cygpath with the -w flag to convert to Windows format
36+
process = subprocess.run(['cygpath', '-w', msys_path], capture_output=True, text=True, check=True)
37+
windows_path = process.stdout.strip()
38+
return windows_path
39+
except FileNotFoundError:
40+
print("Error: cygpath not found. Make sure it's in your PATH.")
41+
return None
42+
except subprocess.CalledProcessError as e:
43+
print(f"Error converting path: {e}")
44+
return None
45+
46+
def should_convert(var: str,
47+
convert: list[str] | None):
48+
"""Check the variable name against convert list"""
49+
if var in convert:
50+
return True
51+
else:
52+
return False
53+
54+
def main(parsed_args: argparse.Namespace):
55+
converted_paths = {}
56+
57+
for var, value in os.environ.items():
58+
if not value or not should_convert(var,
59+
parsed_args.convert):
60+
continue
61+
converted_path = msys_to_windows_path(value)
62+
converted_paths[var] = converted_path
63+
64+
var_str = '\n'.join(f'export {k}="{v}"'
65+
for k, v in converted_paths.items())
66+
# The string can then be piped into `source`, to re-set the
67+
# 'converted' variables.
68+
print(var_str)
69+
70+
71+
if __name__ == '__main__':
72+
parser = argparse.ArgumentParser(description=(
73+
'Convert MSYS paths in environment variables to Windows paths.'))
74+
parser.add_argument('--convert',
75+
nargs='+',
76+
required=True,
77+
help='Space separated list of environment variables to convert. E.g: --convert env_var1 env_var2')
78+
args = parser.parse_args()
79+
80+
main(args)

0 commit comments

Comments
 (0)