Skip to content

Commit a10a30e

Browse files
Merge remote-tracking branch 'upstream/main' into rocm7.1_internal_testing_IFU_2025-09-23
# Conflicts: # .ci/aarch64_linux/aarch64_ci_build.sh # .ci/aarch64_linux/aarch64_wheel_ci_build.py # .ci/docker/build.sh # .ci/docker/ci_commit_pins/huggingface-requirements.txt # .ci/docker/ci_commit_pins/triton.txt # .ci/docker/common/install_rocm.sh # .ci/docker/requirements-ci.txt # .ci/docker/requirements-docs.txt # .ci/libtorch/build.sh # .ci/lumen_cli/cli/lib/core/vllm/lib.py # .ci/lumen_cli/cli/lib/core/vllm/vllm_build.py # .ci/lumen_cli/cli/lib/core/vllm/vllm_test.py # .ci/wheel/build_wheel.sh # .github/ci_commit_pins/audio.txt # .github/ci_commit_pins/vllm.txt # .github/ci_commit_pins/xla.txt # .github/ci_configs/vllm/Dockerfile.tmp_vllm # .github/scripts/generate_binary_build_matrix.py # .github/templates/macos_binary_build_workflow.yml.j2 # .github/workflows/build-vllm-wheel.yml # .github/workflows/docker-builds.yml # .github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml # .github/workflows/generated-linux-binary-manywheel-main.yml # .github/workflows/generated-linux-binary-manywheel-nightly.yml # .github/workflows/generated-linux-binary-manywheel-rocm-main.yml # .github/workflows/generated-macos-arm64-binary-libtorch-release-nightly.yml # .github/workflows/generated-macos-arm64-binary-wheel-nightly.yml # .github/workflows/inductor-nightly.yml # .github/workflows/inductor-perf-test-nightly-x86-zen.yml # .github/workflows/inductor-perf-test-nightly-x86.yml # .github/workflows/inductor-periodic.yml # .github/workflows/inductor-unittest.yml # .github/workflows/inductor.yml # .github/workflows/operator_benchmark.yml # .github/workflows/pull.yml # .github/workflows/trunk.yml # .github/workflows/vllm.yml # aten/src/ATen/CMakeLists.txt # aten/src/ATen/DLConvertor.cpp # aten/src/ATen/cuda/CUDABlas.cpp # aten/src/ATen/native/CPUBlas.cpp # aten/src/ATen/native/LinearAlgebra.cpp # aten/src/ATen/native/Normalization.cpp # aten/src/ATen/native/cuda/Blas.cpp # aten/src/ATen/native/cuda/int8mm.cu # aten/src/ATen/native/cudnn/MHA.cpp # aten/src/ATen/native/miopen/BatchNorm_miopen.cpp # aten/src/ATen/native/miopen/Conv_miopen.cpp # aten/src/ATen/native/mps/operations/GridSampler.mm # aten/src/ATen/native/native_functions.yaml # aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm # aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h # benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv # benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_training.csv # benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv # benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv # benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv # benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv # benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_training.csv # benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv # benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv # benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_training.csv # benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv # benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_training.csv # benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv # benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_training.csv # benchmarks/dynamo/ci_expected_accuracy/rocm/aot_eager_torchbench_inference.csv # benchmarks/dynamo/ci_expected_accuracy/rocm/dynamic_aot_eager_torchbench_inference.csv # benchmarks/dynamo/ci_expected_accuracy/rocm/dynamo_eager_torchbench_inference.csv # benchmarks/dynamo/pr_time_benchmarks/expected_results.csv # benchmarks/operator_benchmark/benchmark_core.py # build_variables.bzl # c10/cuda/CUDAFunctions.cpp # cmake/Codegen.cmake # cmake/External/aotriton.cmake # docs/source/accelerator/index.md # docs/source/accelerator/operators.md # functorch/dim/__init__.py # functorch/dim/wrap_type.py # requirements-build.txt # requirements.txt # test/cpp/nativert/CMakeLists.txt # test/cpp/nativert/test_triton_kernel_manager_registration.cpp # test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp # test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py # test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py # test/cpp_extensions/open_registration_extension/torch_openreg/README.md # test/cpp_extensions/open_registration_extension/torch_openreg/setup.py # test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/README.md # test/cpp_extensions/open_registration_extension/torch_openreg/third_party/openreg/example/example.cpp # test/cpp_extensions/open_registration_extension/torch_openreg/torch_openreg/__init__.py # test/distributed/_composable/fsdp/test_fully_shard_training.py # test/distributed/_composable/test_composability/test_2d_composability.py # test/distributed/fsdp/test_fsdp_comm_hooks.py # test/distributed/tensor/parallel/test_tp_examples.py # test/distributed/tensor/test_attention.py # test/distributed/tensor/test_dtensor_compile.py # test/distributed/tensor/test_dtensor_ops.py # test/distributed/tensor/test_op_schema.py # test/distributed/test_inductor_collectives.py # test/distributed/test_nvshmem.py # test/distributed/test_nvshmem_triton.py # test/distributed/test_symmetric_memory.py # test/dynamo/test_activation_checkpointing.py # test/dynamo/test_aot_compile.py # test/dynamo/test_callback.py # test/dynamo/test_error_messages.py # test/dynamo/test_guard_serialization.py # test/dynamo/test_misc.py # test/dynamo/test_package.py # test/dynamo/test_structured_trace.py # test/export/test_export.py # test/export/test_export_opinfo.py # test/export/test_passes.py # test/export/test_serialize.py # test/functorch/test_control_flow.py # test/inductor/test_aot_inductor.py # test/inductor/test_aot_inductor_package.py # test/inductor/test_flex_attention.py # test/inductor/test_fxir_backend.py # test/inductor/test_loop_ordering.py # test/inductor/test_max_autotune.py # test/inductor/test_torchinductor.py # test/nn/test_convolution.py # test/nn/test_pooling.py # test/run_test.py # test/slow_tests.json # test/test_binary_ufuncs.py # test/test_dynamic_shapes.py # test/test_matmul_cuda.py # test/test_nestedtensor.py # test/test_nn.py # test/test_openreg.py # third_party/xpu.txt # tools/flight_recorder/components/config_manager.py # tools/pyi/gen_pyi.py # torch/_C/_dynamo/guards.pyi # torch/_dynamo/aot_compile.py # torch/_dynamo/convert_frame.py # torch/_dynamo/functional_export.py # torch/_dynamo/graph_break_registry.json # torch/_dynamo/guards.py # torch/_dynamo/output_graph.py # torch/_dynamo/package.py # torch/_dynamo/symbolic_convert.py # torch/_dynamo/variables/higher_order_ops.py # torch/_dynamo/variables/lists.py # torch/_dynamo/variables/optimizer.py # torch/_export/serde/serialize.py # torch/_export/wrappers.py # torch/_functorch/_aot_autograd/autograd_cache.py # torch/_higher_order_ops/__init__.py # torch/_higher_order_ops/associative_scan.py # torch/_higher_order_ops/flex_attention.py # torch/_higher_order_ops/triton_kernel_wrap.py # torch/_inductor/choices.py # torch/_inductor/codegen/cpp.py # torch/_inductor/codegen/cpp_micro_gemm.py # torch/_inductor/codegen/cpp_wrapper_cpu.py # torch/_inductor/codegen/triton.py # torch/_inductor/codegen/wrapper_fxir.py # torch/_inductor/config.py # torch/_inductor/cpp_builder.py # torch/_inductor/decomposition.py # torch/_inductor/kernel/bmm.py # torch/_inductor/kernel/flex/flex_attention.py # torch/_inductor/kernel/flex/templates/flex_attention.py.jinja # torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja # torch/_inductor/kernel/flex/templates/flex_decode.py.jinja # torch/_inductor/kernel/flex/templates/utilities.py.jinja # torch/_inductor/kernel/mm.py # torch/_inductor/kernel/mm_plus_mm.py # torch/_inductor/kernel_template_choice.py # torch/_inductor/memory.py # torch/_inductor/runtime/triton_heuristics.py # torch/_inductor/scheduler.py # torch/_inductor/select_algorithm.py # torch/_inductor/template_heuristics/base.py # torch/_inductor/template_heuristics/triton.py # torch/_inductor/utils.py # torch/_meta_registrations.py # torch/_prims_common/__init__.py # torch/csrc/Module.cpp # torch/csrc/autograd/python_variable.cpp # torch/csrc/autograd/python_variable_indexing.cpp # torch/csrc/distributed/c10d/FlightRecorder.cpp # torch/csrc/distributed/c10d/ProcessGroupGloo.hpp # torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu # torch/csrc/inductor/aoti_runtime/utils.h # torch/csrc/stable/accelerator.h # torch/csrc/stable/ops.h # torch/csrc/utils/generated_serialization_types.h # torch/csrc/utils/tensor_numpy.cpp # torch/distributed/_symmetric_memory/_nvshmem_triton.py # torch/distributed/device_mesh.py # torch/distributed/pipelining/_schedule_visualizer.py # torch/distributed/tensor/_api.py # torch/distributed/tensor/_dispatch.py # torch/distributed/tensor/_op_schema.py # torch/distributed/tensor/_random.py # torch/distributed/tensor/_sharding_prop.py # torch/export/_trace.py # torch/export/_unlift.py # torch/export/exported_program.py # torch/fx/experimental/proxy_tensor.py # torch/nativert/executor/triton/CpuTritonKernelManager.cpp # torch/nativert/executor/triton/CudaTritonKernelManager.cpp # torch/nativert/executor/triton/TritonKernelManager.h # torch/nativert/kernels/KernelHandlerRegistry.cpp # torch/nativert/kernels/TritonKernel.cpp # torch/nested/_internal/ops.py # torch/onnx/__init__.py # torch/overrides.py # torch/testing/_internal/common_cuda.py # torch/testing/_internal/common_distributed.py # torch/testing/_internal/common_quantization.py # torch/testing/_internal/common_utils.py # torch/testing/_internal/distributed/_tensor/common_dtensor.py # torch/testing/_internal/distributed/fake_pg.py # torch/testing/_internal/hop_db.py # torch/utils/_python_dispatch.py # torch/utils/data/datapipes/iter/combinatorics.py
2 parents 7ea3967 + aff76c0 commit a10a30e

File tree

984 files changed

+39283
-6622
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

984 files changed

+39283
-6622
lines changed

.ci/aarch64_linux/aarch64_ci_build.sh

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,15 @@ GPU_ARCH_VERSION=${GPU_ARCH_VERSION:-}
55

66
# Set CUDA architecture lists to match x86 build_cuda.sh
77
if [[ "$GPU_ARCH_VERSION" == *"12.6"* ]]; then
8+
<<<<<<< HEAD
89
export TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;8.0;9.0"
910
elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then
1011
export TORCH_CUDA_ARCH_LIST="7.0;8.0;9.0;10.0;12.0"
12+
=======
13+
export TORCH_CUDA_ARCH_LIST="8.0;9.0"
14+
elif [[ "$GPU_ARCH_VERSION" == *"12.8"* ]]; then
15+
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;12.0"
16+
>>>>>>> upstream/main
1117
elif [[ "$GPU_ARCH_VERSION" == *"13.0"* ]]; then
1218
export TORCH_CUDA_ARCH_LIST="8.0;9.0;10.0;11.0;12.0+PTX"
1319
fi
@@ -31,8 +37,7 @@ pip install -r /pytorch/requirements.txt
3137
pip install auditwheel==6.2.0 wheel
3238
if [ "$DESIRED_CUDA" = "cpu" ]; then
3339
echo "BASE_CUDA_VERSION is not set. Building cpu wheel."
34-
#USE_PRIORITIZED_TEXT_FOR_LD for enable linker script optimization https://github.com/pytorch/pytorch/pull/121975/files
35-
USE_PRIORITIZED_TEXT_FOR_LD=1 python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
40+
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn
3641
else
3742
echo "BASE_CUDA_VERSION is set to: $DESIRED_CUDA"
3843
export USE_SYSTEM_NCCL=1
@@ -42,13 +47,20 @@ else
4247
echo "Bundling CUDA libraries with wheel for aarch64."
4348
else
4449
echo "Using nvidia libs from pypi for aarch64."
50+
<<<<<<< HEAD
4551
# Fix platform constraints in PYTORCH_EXTRA_INSTALL_REQUIREMENTS for aarch64
4652
# Replace 'platform_machine == "x86_64"' with 'platform_machine == "aarch64"'
4753
export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS//platform_machine == \'x86_64\'/platform_machine == \'aarch64\'}"
54+
=======
55+
>>>>>>> upstream/main
4856
echo "Updated PYTORCH_EXTRA_INSTALL_REQUIREMENTS for aarch64: $PYTORCH_EXTRA_INSTALL_REQUIREMENTS"
4957
export USE_NVIDIA_PYPI_LIBS=1
5058
fi
5159

60+
<<<<<<< HEAD
5261
#USE_PRIORITIZED_TEXT_FOR_LD for enable linker script optimization https://github.com/pytorch/pytorch/pull/121975/files
5362
USE_PRIORITIZED_TEXT_FOR_LD=1 python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
63+
=======
64+
python /pytorch/.ci/aarch64_linux/aarch64_wheel_ci_build.py --enable-mkldnn --enable-cuda
65+
>>>>>>> upstream/main
5466
fi

.ci/aarch64_linux/aarch64_wheel_ci_build.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ def package_cuda_wheel(wheel_path, desired_cuda) -> None:
138138
folder = os.path.dirname(wheel_path)
139139
os.mkdir(f"{folder}/tmp")
140140
os.system(f"unzip {wheel_path} -d {folder}/tmp")
141+
<<<<<<< HEAD
142+
=======
143+
# Delete original wheel since it will be repackaged
144+
os.system(f"rm {wheel_path}")
145+
>>>>>>> upstream/main
141146

142147
# Check if we should use PyPI NVIDIA libraries or bundle system libraries
143148
use_nvidia_pypi_libs = os.getenv("USE_NVIDIA_PYPI_LIBS", "0") == "1"
@@ -211,7 +216,12 @@ def package_cuda_wheel(wheel_path, desired_cuda) -> None:
211216
]
212217

213218
# CUDA version-specific libraries
219+
<<<<<<< HEAD
214220
if "130" in desired_cuda:
221+
=======
222+
if "13" in desired_cuda:
223+
minor_version = desired_cuda[-1]
224+
>>>>>>> upstream/main
215225
version_specific_libs = [
216226
"/usr/local/cuda/extras/CUPTI/lib64/libcupti.so.13",
217227
"/usr/local/cuda/lib64/libcublas.so.13",
@@ -221,7 +231,11 @@ def package_cuda_wheel(wheel_path, desired_cuda) -> None:
221231
"/usr/local/cuda/lib64/libcusolver.so.12",
222232
"/usr/local/cuda/lib64/libnvJitLink.so.13",
223233
"/usr/local/cuda/lib64/libnvrtc.so.13",
234+
<<<<<<< HEAD
224235
"/usr/local/cuda/lib64/libnvrtc-builtins.so.13.0",
236+
=======
237+
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.13.{minor_version}",
238+
>>>>>>> upstream/main
225239
]
226240
elif "12" in desired_cuda:
227241
# Get the last character for libnvrtc-builtins version (e.g., "129" -> "9")
@@ -237,6 +251,11 @@ def package_cuda_wheel(wheel_path, desired_cuda) -> None:
237251
"/usr/local/cuda/lib64/libnvrtc.so.12",
238252
f"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.{minor_version}",
239253
]
254+
<<<<<<< HEAD
255+
=======
256+
else:
257+
raise ValueError(f"Unsupported CUDA version: {desired_cuda}.")
258+
>>>>>>> upstream/main
240259

241260
# Combine all libraries
242261
libs_to_copy = common_libs + version_specific_libs
@@ -275,14 +294,7 @@ def complete_wheel(folder: str) -> str:
275294
f"/{folder}/dist/{repaired_wheel_name}",
276295
)
277296
else:
278-
repaired_wheel_name = wheel_name.replace(
279-
"linux_aarch64", "manylinux_2_28_aarch64"
280-
)
281-
print(f"Renaming {wheel_name} wheel to {repaired_wheel_name}")
282-
os.rename(
283-
f"/{folder}/dist/{wheel_name}",
284-
f"/{folder}/dist/{repaired_wheel_name}",
285-
)
297+
repaired_wheel_name = list_dir(f"/{folder}/dist")[0]
286298

287299
print(f"Copying {repaired_wheel_name} to artifacts")
288300
shutil.copy2(
@@ -319,7 +331,7 @@ def parse_arguments():
319331
).decode()
320332

321333
print("Building PyTorch wheel")
322-
build_vars = "CMAKE_SHARED_LINKER_FLAGS=-Wl,-z,max-page-size=0x10000 "
334+
build_vars = ""
323335
# MAX_JOB=5 is not required for CPU backend (see commit 465d98b)
324336
if enable_cuda:
325337
build_vars += "MAX_JOBS=5 "

.ci/docker/build.sh

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,12 @@ case "$tag" in
214214
TRITON=yes
215215
;;
216216
pytorch-linux-jammy-py3-gcc11-inductor-benchmarks)
217+
<<<<<<< HEAD
217218
# TODO (huydhn): Upgrade this to Python >= 3.10
218219
ANACONDA_PYTHON_VERSION=3.9
220+
=======
221+
ANACONDA_PYTHON_VERSION=3.10
222+
>>>>>>> upstream/main
219223
GCC_VERSION=11
220224
VISION=yes
221225
KATEX=yes
@@ -263,13 +267,10 @@ case "$tag" in
263267
TRITON_CPU=yes
264268
;;
265269
pytorch-linux-jammy-linter)
266-
# TODO: Use 3.9 here because of this issue https://github.com/python/mypy/issues/13627.
267-
# We will need to update mypy version eventually, but that's for another day. The task
268-
# would be to upgrade mypy to 1.0.0 with Python 3.11
269-
PYTHON_VERSION=3.9
270+
PYTHON_VERSION=3.10
270271
;;
271-
pytorch-linux-jammy-cuda12.8-cudnn9-py3.9-linter)
272-
PYTHON_VERSION=3.9
272+
pytorch-linux-jammy-cuda12.8-cudnn9-py3.10-linter)
273+
PYTHON_VERSION=3.10
273274
CUDA_VERSION=12.8.1
274275
;;
275276
pytorch-linux-jammy-aarch64-py3.10-gcc11)

.ci/docker/centos-rocm/Dockerfile

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,13 @@ ENV INSTALLED_VISION ${VISION}
5959

6060
# Install rocm
6161
ARG ROCM_VERSION
62+
RUN mkdir ci_commit_pins
63+
COPY ./common/common_utils.sh common_utils.sh
64+
COPY ./ci_commit_pins/rocm-composable-kernel.txt ci_commit_pins/rocm-composable-kernel.txt
6265
COPY ./common/install_rocm.sh install_rocm.sh
6366
RUN bash ./install_rocm.sh
64-
RUN rm install_rocm.sh
67+
RUN rm install_rocm.sh common_utils.sh
68+
RUN rm -r ci_commit_pins
6569
COPY ./common/install_rocm_magma.sh install_rocm_magma.sh
6670
RUN bash ./install_rocm_magma.sh ${ROCM_VERSION}
6771
RUN rm install_rocm_magma.sh
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
56392aa978594cc155fa8af48cd949f5b5f1823a
1+
e0dda9059d082537cee36be6c5e4fe3b18c880c0
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
1+
<<<<<<< HEAD
12
transformers==4.54.0
3+
=======
4+
transformers==4.56.0
5+
>>>>>>> upstream/main
26
soxr==0.5.0
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
7fe50dc3da2069d6645d9deb8c017a876472a977
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1+
<<<<<<< HEAD
12
6193b30becb1ac7be704cf87b8cb9bf13e7f9689
3+
=======
4+
bbb06c0334a6772b92d24bde54956e675c8c6604
5+
>>>>>>> upstream/main

.ci/docker/common/install_executorch.sh

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,27 @@ install_pip_dependencies() {
4242
# A workaround, ExecuTorch has moved to numpy 2.0 which is not compatible with the current
4343
# numba and scipy version used in PyTorch CI
4444
conda_run pip uninstall -y numba scipy
45+
# Yaspin is needed for running CI test (get_benchmark_analysis_data.py)
46+
pip_install yaspin==3.1.0
4547

4648
popd
4749
}
4850

4951
setup_executorch() {
50-
pushd executorch
51-
5252
export PYTHON_EXECUTABLE=python
53-
export CMAKE_ARGS="-DEXECUTORCH_BUILD_PYBIND=ON -DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON"
53+
export CMAKE_ARGS="-DEXECUTORCH_BUILD_PYBIND=ON -DEXECUTORCH_BUILD_XNNPACK=ON -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON -DEXECUTORCH_BUILD_TESTS=ON"
5454

5555
as_jenkins .ci/scripts/setup-linux.sh --build-tool cmake || true
56-
popd
5756
}
5857

59-
clone_executorch
60-
install_buck2
61-
install_conda_dependencies
62-
install_pip_dependencies
63-
setup_executorch
58+
if [ $# -eq 0 ]; then
59+
clone_executorch
60+
install_buck2
61+
install_conda_dependencies
62+
install_pip_dependencies
63+
pushd executorch
64+
setup_executorch
65+
popd
66+
else
67+
"$@"
68+
fi

.ci/docker/common/install_rocm.sh

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
set -ex
44

5+
# for pip_install function
6+
source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"
7+
8+
ROCM_COMPOSABLE_KERNEL_VERSION="$(cat $(dirname $0)/../ci_commit_pins/rocm-composable-kernel.txt)"
9+
510
ver() {
611
printf "%3d%03d%03d%03d" $(echo "$1" | tr '.' ' ');
712
}
@@ -109,8 +114,12 @@ EOF
109114
rm -rf HIP clr
110115
fi
111116

117+
<<<<<<< HEAD
112118
# temporary hipblasLT dependency install
113119
apt install libmsgpackc2
120+
=======
121+
pip_install "git+https://github.com/rocm/composable_kernel@$ROCM_COMPOSABLE_KERNEL_VERSION"
122+
>>>>>>> upstream/main
114123

115124
# Cleanup
116125
apt-get autoclean && apt-get clean
@@ -195,6 +204,8 @@ install_centos() {
195204
sqlite3 $kdb "PRAGMA journal_mode=off; PRAGMA VACUUM;"
196205
done
197206

207+
pip_install "git+https://github.com/rocm/composable_kernel@$ROCM_COMPOSABLE_KERNEL_VERSION"
208+
198209
# Cleanup
199210
yum clean all
200211
rm -rf /var/cache/yum

0 commit comments

Comments
 (0)